From 546120cd26118699c99a3597cb0229c91eca4369 Mon Sep 17 00:00:00 2001 From: Flavio Castelli Date: Fri, 17 Nov 2023 11:13:53 +0100 Subject: [PATCH] refactor: reduce memory usage Re-architect how the waPC policies are handled. This is required to reduce the amount of memory used by Policy Server. Signed-off-by: Flavio Castelli --- Cargo.toml | 2 + src/evaluation_context.rs | 124 +++++++ src/lib.rs | 2 +- src/policy.rs | 172 --------- src/policy_evaluator.rs | 94 +++-- src/policy_evaluator_builder.rs | 260 +++++--------- src/policy_tracing.rs | 4 +- src/runtimes/mod.rs | 21 ++ src/runtimes/wapc/callback.rs | 127 +++---- .../wapc/evaluation_context_registry.rs | 329 ++++++++++++++++++ src/runtimes/wapc/mapping.rs | 23 -- src/runtimes/wapc/mod.rs | 9 +- src/runtimes/wapc/stack.rs | 38 +- 13 files changed, 694 insertions(+), 511 deletions(-) create mode 100644 src/evaluation_context.rs delete mode 100644 src/policy.rs create mode 100644 src/runtimes/wapc/evaluation_context_registry.rs delete mode 100644 src/runtimes/wapc/mapping.rs diff --git a/Cargo.toml b/Cargo.toml index aacffbac..2eacfa5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,3 +65,5 @@ assert-json-diff = "2.0" k8s-openapi = { version = "0.20.0", default-features = false, features = [ "v1_28", ] } +rstest = "0.18" +test-context = "0.1" diff --git a/src/evaluation_context.rs b/src/evaluation_context.rs new file mode 100644 index 00000000..58ea25ce --- /dev/null +++ b/src/evaluation_context.rs @@ -0,0 +1,124 @@ +use std::collections::BTreeSet; +use std::fmt; +use tokio::sync::mpsc; + +use crate::callback_requests::CallbackRequest; +use crate::policy_metadata::ContextAwareResource; + +/// A struct that holds metadata and other data that are needed when a policy +/// is being evaluated +#[derive(Clone)] +pub struct EvaluationContext { + /// The policy identifier. 
This is mostly relevant for Policy Server, + /// which uses the identifier provided by the user inside of the `policy.yml` + /// file + pub policy_id: String, + + /// Channel used by the synchronous world (like the `host_callback` waPC function, + /// but also Burrego for k8s context aware data), + /// to request the computation of code that can only be run inside of an + /// asynchronous block + pub callback_channel: Option>, + + /// List of ContextAwareResource the policy is granted access to. + pub ctx_aware_resources_allow_list: BTreeSet, +} + +impl EvaluationContext { + /// Checks if a policy has access to a Kubernetes resource, based on the privileges + /// that have been granted by the user + pub(crate) fn can_access_kubernetes_resource(&self, api_version: &str, kind: &str) -> bool { + let wanted_resource = ContextAwareResource { + api_version: api_version.to_string(), + kind: kind.to_string(), + }; + + self.ctx_aware_resources_allow_list + .contains(&wanted_resource) + } + + /// Copy data from another `EvaluationContext` instance + pub(crate) fn copy_from(&mut self, other: &EvaluationContext) { + if self.policy_id == other.policy_id { + // The current evaluation context is about the very same policy + // There's nothing to be done + return; + } + self.policy_id = other.policy_id.clone(); + self.callback_channel = other.callback_channel.clone(); + self.ctx_aware_resources_allow_list = other.ctx_aware_resources_allow_list.clone(); + } +} + +impl fmt::Debug for EvaluationContext { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let callback_channel = match self.callback_channel { + Some(_) => "Some(...)", + None => "None", + }; + + write!( + f, + r#"EvaluationContext {{ policy_id: "{}", callback_channel: {}, allowed_kubernetes_resources: {:?} }}"#, + self.policy_id, callback_channel, self.ctx_aware_resources_allow_list, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[rstest] + #[case("nothing allowed", 
BTreeSet::new(), "v1", "Secret", false)] + #[case( + "try to access denied resource", + BTreeSet::from([ + ContextAwareResource{ + api_version: "v1".to_string(), + kind: "ConfigMap".to_string(), + }]), + "v1", + "Secret", + false, + )] + #[case( + "access allowed resource", + BTreeSet::from([ + ContextAwareResource{ + api_version: "v1".to_string(), + kind: "ConfigMap".to_string(), + }]), + "v1", + "ConfigMap", + true, + )] + + fn can_access_kubernetes_resource( + #[case] name: &str, + #[case] allowed_resources: BTreeSet, + #[case] api_version: &str, + #[case] kind: &str, + #[case] allowed: bool, + ) { + let ctx = EvaluationContext { + policy_id: name.to_string(), + callback_channel: None, + ctx_aware_resources_allow_list: allowed_resources, + }; + + let requested_resource = ContextAwareResource { + api_version: api_version.to_string(), + kind: kind.to_string(), + }; + + assert_eq!( + allowed, + ctx.can_access_kubernetes_resource( + &requested_resource.api_version, + &requested_resource.kind + ) + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 820eb255..d104cf75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ pub mod callback_handler; pub mod callback_requests; pub mod constants; pub mod errors; -pub(crate) mod policy; +pub mod evaluation_context; pub mod policy_artifacthub; pub mod policy_evaluator; pub mod policy_evaluator_builder; diff --git a/src/policy.rs b/src/policy.rs deleted file mode 100644 index 38a23a19..00000000 --- a/src/policy.rs +++ /dev/null @@ -1,172 +0,0 @@ -use anyhow::Result; -use std::clone::Clone; -use std::collections::BTreeSet; -use std::fmt; -use tokio::sync::mpsc; - -use crate::callback_requests::CallbackRequest; -use crate::policy_metadata::ContextAwareResource; - -/// Minimal amount of information about a policy that need to -/// be always accessible at runtime. 
-/// -/// This struct is used extensively inside of the `host_callback` -/// function to obtain information about the policy that is invoking -/// a host waPC function, and handle the request. -#[derive(Clone, Default)] -pub struct Policy { - /// The policy identifier. This is mostly relevant for Policy Server, - /// which uses the identifier provided by the user inside of the `policy.yml` - /// file - pub id: String, - - /// This is relevant only for waPC-based policies. This is the unique ID - /// associated to the waPC policy. - /// Burrego policies have this field set to `None` - instance_id: Option, - - /// Channel used by the synchronous world (the `host_callback` waPC function), - /// to request the computation of code that can only be run inside of an - /// asynchronous block - pub callback_channel: Option>, - - /// List of ContextAwareResource the policy is granted access to. - /// Currently, this is relevant only for waPC based policies - pub ctx_aware_resources_allow_list: BTreeSet, -} - -impl fmt::Debug for Policy { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let callback_channel = match self.callback_channel { - Some(_) => "Some(...)", - None => "None", - }; - - write!( - f, - r#"Policy {{ id: "{}", instance_id: {:?}, callback_channel: {} }}"#, - self.id, self.instance_id, callback_channel, - ) - } -} - -impl PartialEq for Policy { - fn eq(&self, other: &Self) -> bool { - self.id == other.id && self.instance_id == other.instance_id - } -} - -impl Policy { - pub(crate) fn new( - id: String, - policy_id: Option, - callback_channel: Option>, - ctx_aware_resources_allow_list: Option>, - ) -> Result { - Ok(Policy { - id, - instance_id: policy_id, - callback_channel, - ctx_aware_resources_allow_list: ctx_aware_resources_allow_list.unwrap_or_default(), - }) - } - - pub(crate) fn can_access_kubernetes_resource(&self, api_version: &str, kind: &str) -> bool { - let wanted_resource = ContextAwareResource { - api_version: api_version.to_string(), - 
kind: kind.to_string(), - }; - - self.ctx_aware_resources_allow_list - .contains(&wanted_resource) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn can_access_kubernetes_resource_empty_allow_list() { - let policy = - Policy::new("test".to_string(), None, None, None).expect("cannot create policy"); - - let requested_resource = ContextAwareResource { - api_version: "v1".to_string(), - kind: "Secret".to_string(), - }; - - assert!(!policy.can_access_kubernetes_resource( - &requested_resource.api_version, - &requested_resource.kind - )); - } - - #[test] - fn can_access_kubernetes_resource_denied() { - let requested_resource = ContextAwareResource { - api_version: "v1".to_string(), - kind: "Secret".to_string(), - }; - - let mut allowed_resources = BTreeSet::new(); - allowed_resources.insert(ContextAwareResource { - api_version: "v1".to_string(), - kind: "Pod".to_string(), - }); - - let policy = Policy::new("test".to_string(), None, None, Some(allowed_resources)) - .expect("cannot create policy"); - - assert!(!policy.can_access_kubernetes_resource( - &requested_resource.api_version, - &requested_resource.kind - )); - } - - #[test] - fn can_access_kubernetes_resource_allowed() { - let requested_resource = ContextAwareResource { - api_version: "v1".to_string(), - kind: "Secret".to_string(), - }; - - let mut allowed_resources = BTreeSet::new(); - allowed_resources.insert(ContextAwareResource { - api_version: "v1".to_string(), - kind: "Pod".to_string(), - }); - allowed_resources.insert(ContextAwareResource { - api_version: "v1".to_string(), - kind: "Secret".to_string(), - }); - - let policy = Policy::new("test".to_string(), None, None, Some(allowed_resources)) - .expect("cannot create policy"); - - assert!(policy.can_access_kubernetes_resource( - &requested_resource.api_version, - &requested_resource.kind - )); - } - - #[test] - fn policy_fields_test() { - let id = "test".to_string(); - let policy_id = Some(1); - let callback_channel = None; - let 
ctx_aware_resources_allow_list = BTreeSet::new(); - - let policy = Policy::new( - id.clone(), - policy_id, - callback_channel.clone(), - Some(ctx_aware_resources_allow_list.clone()), - ) - .expect("cannot create policy"); - - assert!(policy.id == id); - assert!(policy.instance_id == policy_id); - assert!(policy.ctx_aware_resources_allow_list == ctx_aware_resources_allow_list); - } -} diff --git a/src/policy_evaluator.rs b/src/policy_evaluator.rs index aad5b2c9..2c86bde5 100644 --- a/src/policy_evaluator.rs +++ b/src/policy_evaluator.rs @@ -7,8 +7,9 @@ use std::{convert::TryFrom, fmt}; use crate::admission_request::AdmissionRequest; use crate::admission_response::AdmissionResponse; -use crate::policy::Policy; +use crate::evaluation_context::EvaluationContext; use crate::runtimes::rego::Runtime as BurregoRuntime; +use crate::runtimes::wapc; use crate::runtimes::wapc::Runtime as WapcRuntime; use crate::runtimes::wasi_cli::Runtime as WasiRuntime; use crate::runtimes::Runtime; @@ -43,7 +44,7 @@ pub enum ValidateRequest { } impl ValidateRequest { - pub(crate) fn uid(&self) -> &str { + pub fn uid(&self) -> &str { match self { ValidateRequest::Raw(raw_req) => raw_req .get("uid") @@ -74,64 +75,68 @@ impl TryFrom for RegoPolicyExecutionMode { } } -pub(crate) type PolicySettings = serde_json::Map; - -pub trait Evaluator { - fn validate(&mut self, request: ValidateRequest) -> AdmissionResponse; - fn validate_settings(&mut self) -> SettingsValidationResponse; - fn protocol_version(&mut self) -> Result; - fn policy_id(&self) -> String; -} +/// Settings specified by the user for a given policy. 
+pub type PolicySettings = serde_json::Map; pub struct PolicyEvaluator { - pub(crate) runtime: Runtime, - pub(crate) settings: PolicySettings, - pub policy: Policy, + runtime: Runtime, + worker_id: u64, + policy_id: String, } -impl fmt::Debug for PolicyEvaluator { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PolicyEvaluator") - .field("id", &self.policy.id) - .field("settings", &self.settings) - .finish() +impl PolicyEvaluator { + pub(crate) fn new(policy_id: &str, worker_id: u64, runtime: Runtime) -> Self { + Self { + runtime, + worker_id, + policy_id: policy_id.to_owned(), + } } -} -impl Evaluator for PolicyEvaluator { - fn policy_id(&self) -> String { - self.policy.id.clone() + #[cfg(test)] + pub(crate) fn runtime(&self) -> &Runtime { + &self.runtime } - #[tracing::instrument(skip(request))] - fn validate(&mut self, request: ValidateRequest) -> AdmissionResponse { + pub fn policy_id(&self) -> String { + self.policy_id.clone() + } + + #[tracing::instrument(skip(request, eval_ctx))] + pub fn validate( + &mut self, + request: ValidateRequest, + settings: &PolicySettings, + eval_ctx: &EvaluationContext, + ) -> AdmissionResponse { match self.runtime { Runtime::Wapc(ref mut wapc_host) => { - WapcRuntime(wapc_host).validate(&self.settings, &request) + wapc::evaluation_context_registry::set_worker_ctx(self.worker_id, eval_ctx); + WapcRuntime(wapc_host).validate(settings, &request) } Runtime::Burrego(ref mut burrego_evaluator) => { let kube_ctx = burrego_evaluator.build_kubernetes_context( - self.policy.callback_channel.as_ref(), - &self.policy.ctx_aware_resources_allow_list, + eval_ctx.callback_channel.as_ref(), + &eval_ctx.ctx_aware_resources_allow_list, ); match kube_ctx { - Ok(ctx) => { - BurregoRuntime(burrego_evaluator).validate(&self.settings, &request, &ctx) - } + Ok(ctx) => BurregoRuntime(burrego_evaluator).validate(settings, &request, &ctx), Err(e) => { AdmissionResponse::reject(request.uid().to_string(), e.to_string(), 500) } 
} } - Runtime::Cli(ref mut cli_stack) => { - WasiRuntime(cli_stack).validate(&self.settings, &request) - } + Runtime::Cli(ref mut cli_stack) => WasiRuntime(cli_stack).validate(settings, &request), } } - #[tracing::instrument] - fn validate_settings(&mut self) -> SettingsValidationResponse { - let settings_str = match serde_json::to_string(&self.settings) { + #[tracing::instrument(skip(eval_ctx))] + pub fn validate_settings( + &mut self, + settings: &PolicySettings, + eval_ctx: &EvaluationContext, + ) -> SettingsValidationResponse { + let settings_str = match serde_json::to_string(settings) { Ok(settings) => settings, Err(err) => { return SettingsValidationResponse { @@ -143,6 +148,7 @@ impl Evaluator for PolicyEvaluator { match self.runtime { Runtime::Wapc(ref mut wapc_host) => { + wapc::evaluation_context_registry::set_worker_ctx(self.worker_id, eval_ctx); WapcRuntime(wapc_host).validate_settings(settings_str) } Runtime::Burrego(ref mut burrego_evaluator) => { @@ -154,7 +160,7 @@ impl Evaluator for PolicyEvaluator { } } - fn protocol_version(&mut self) -> Result { + pub fn protocol_version(&mut self) -> Result { match &mut self.runtime { Runtime::Wapc(ref mut wapc_host) => WapcRuntime(wapc_host).protocol_version(), _ => Err(anyhow!( @@ -164,6 +170,18 @@ impl Evaluator for PolicyEvaluator { } } +impl fmt::Debug for PolicyEvaluator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let runtime = self.runtime.to_string(); + + f.debug_struct("PolicyEvaluator") + .field("policy_id", &self.policy_id) + .field("worker_id", &self.worker_id) + .field("runtime", &runtime) + .finish() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/policy_evaluator_builder.rs b/src/policy_evaluator_builder.rs index 1e4300c2..ffe339e3 100644 --- a/src/policy_evaluator_builder.rs +++ b/src/policy_evaluator_builder.rs @@ -2,14 +2,15 @@ use anyhow::{anyhow, Result}; use std::collections::BTreeSet; use std::convert::TryInto; use std::path::Path; +use 
std::sync::{Arc, RwLock}; use tokio::sync::mpsc; use wasmtime_provider::wasmtime; use crate::callback_requests::CallbackRequest; -use crate::policy::Policy; +use crate::evaluation_context::EvaluationContext; use crate::policy_evaluator::{PolicyEvaluator, PolicyExecutionMode}; use crate::policy_metadata::ContextAwareResource; -use crate::runtimes::wapc::WAPC_POLICY_MAPPING; +use crate::runtimes::wapc::evaluation_context_registry::register_policy; use crate::runtimes::{rego::BurregoStack, wapc::WapcStack, wasi_cli, Runtime}; /// Configure behavior of wasmtime [epoch-based interruptions](https://docs.rs/wasmtime/latest/wasmtime/struct.Config.html#method.epoch_interruption) @@ -33,11 +34,11 @@ pub(crate) struct EpochDeadlines { pub struct PolicyEvaluatorBuilder { engine: Option, policy_id: String, + worker_id: u64, policy_file: Option, policy_contents: Option>, policy_module: Option, execution_mode: Option, - settings: Option>, callback_channel: Option>, wasmtime_cache: bool, epoch_deadlines: Option, @@ -45,11 +46,16 @@ pub struct PolicyEvaluatorBuilder { } impl PolicyEvaluatorBuilder { - /// Create a new PolicyEvaluatorBuilder object. The `policy_id` must be - /// specified. - pub fn new(policy_id: String) -> PolicyEvaluatorBuilder { + /// Create a new PolicyEvaluatorBuilder object. + /// * `policy_id`: unique identifier of the policy. This is mostly relevant for PolicyServer. + /// In this case, this is the value used to identify the policy inside of the `policies.yml` + /// file + /// * `worker_id`: unique identifier of the worker that is going to evaluate the policy. 
This + /// is mostly relevant for PolicyServer + pub fn new(policy_id: String, worker_id: u64) -> PolicyEvaluatorBuilder { PolicyEvaluatorBuilder { policy_id, + worker_id, ..Default::default() } } @@ -103,15 +109,6 @@ impl PolicyEvaluatorBuilder { self } - /// Set the settings the policy will use at evaluation time - pub fn settings( - mut self, - s: Option>, - ) -> PolicyEvaluatorBuilder { - self.settings = s; - self - } - /// Set the list of Kubernetes resources the policy can have access to pub fn context_aware_resources_allowed( mut self, @@ -245,47 +242,21 @@ impl PolicyEvaluatorBuilder { let execution_mode = self.execution_mode.unwrap(); - let (policy, runtime) = match execution_mode { - PolicyExecutionMode::KubewardenWapc => { - let wapc_stack = WapcStack::new(engine, module, self.epoch_deadlines)?; - - let policy = Self::from_contents_internal( - self.policy_id.clone(), - self.callback_channel.clone(), - Some(self.ctx_aware_resources_allow_list.clone()), - || Some(wapc_stack.wapc_host_id()), - Policy::new, - execution_mode, - )?; - - let policy_runtime = Runtime::Wapc(wapc_stack); - (policy, policy_runtime) - } + let runtime = match execution_mode { + PolicyExecutionMode::KubewardenWapc => create_wapc_runtime( + &self.policy_id, + self.worker_id, + engine, + module, + self.epoch_deadlines, + self.callback_channel.clone(), + &self.ctx_aware_resources_allow_list, + )?, PolicyExecutionMode::Wasi => { let cli_stack = wasi_cli::Stack::new(engine, module, self.epoch_deadlines)?; - - let policy = Self::from_contents_internal( - self.policy_id.clone(), - None, // callback_channel is not used by WASI policies - None, - || None, - Policy::new, - execution_mode, - )?; - - let policy_runtime = Runtime::Cli(cli_stack); - (policy, policy_runtime) + Runtime::Cli(cli_stack) } PolicyExecutionMode::Opa | PolicyExecutionMode::OpaGatekeeper => { - let policy = Self::from_contents_internal( - self.policy_id.clone(), - self.callback_channel.clone(), - 
Some(self.ctx_aware_resources_allow_list.clone()), - || None, - Policy::new, - execution_mode, - )?; - let mut builder = burrego::EvaluatorBuilder::default() .engine(&engine) .module(module) @@ -296,157 +267,110 @@ impl PolicyEvaluatorBuilder { } let evaluator = builder.build()?; - let policy_runtime = Runtime::Burrego(BurregoStack { + Runtime::Burrego(BurregoStack { evaluator, entrypoint_id: 0, // currently hardcoded to this value policy_execution_mode: execution_mode.try_into()?, - }); - - (policy, policy_runtime) + }) } }; - Ok(PolicyEvaluator { + Ok(PolicyEvaluator::new( + &self.policy_id, + self.worker_id, runtime, - policy, - settings: self.settings.clone().unwrap_or_default(), - }) + )) } +} - fn from_contents_internal( - id: String, - callback_channel: Option>, - ctx_aware_resources_allow_list: Option>, - engine_initializer: E, - policy_initializer: P, - policy_execution_mode: PolicyExecutionMode, - ) -> Result - where - E: Fn() -> Option, - P: Fn( - String, - Option, - Option>, - Option>, - ) -> Result, - { - let instance_id = engine_initializer(); - let policy = policy_initializer( - id, - instance_id, - callback_channel, - ctx_aware_resources_allow_list, - )?; - if policy_execution_mode == PolicyExecutionMode::KubewardenWapc { - WAPC_POLICY_MAPPING - .write() - .expect("cannot write to global WAPC_POLICY_MAPPING") - .insert( - instance_id.ok_or_else(|| anyhow!("invalid policy id"))?, - policy.clone(), - ); - } - Ok(policy) - } +fn create_wapc_runtime( + policy_id: &str, + worker_id: u64, + engine: wasmtime::Engine, + module: wasmtime::Module, + epoch_deadlines: Option, + callback_channel: Option>, + ctx_aware_resources_allow_list: &BTreeSet, +) -> Result { + let wapc_stack = WapcStack::new(engine, module, epoch_deadlines)?; + let eval_ctx = Arc::new(RwLock::new(EvaluationContext { + policy_id: policy_id.to_owned(), + callback_channel, + ctx_aware_resources_allow_list: ctx_aware_resources_allow_list.clone(), + })); + 
register_policy(wapc_stack.wapc_host_id(), worker_id, eval_ctx); + + Ok(Runtime::Wapc(wapc_stack)) } #[cfg(test)] mod tests { use super::*; + use crate::runtimes::wapc::evaluation_context_registry::{ + get_eval_ctx, get_worker_id, tests::is_wapc_instance_registered, + }; #[test] - fn policy_is_registered_in_the_mapping() -> Result<()> { - let policy_name = "policy_is_registered_in_the_mapping"; - - // We cannot set policy.id at build time, because some attributes - // of Policy are private. - let mut policy = Policy::default(); - policy.id = policy_name.to_string(); + fn wapc_policy_is_registered() { + let policy_id = "a-policy"; + let worker_id = 0; - let policy_id = 1; - - PolicyEvaluatorBuilder::from_contents_internal( - "mock_policy".to_string(), - None, - None, - || Some(policy_id), - |_, _, _, _| Ok(policy.clone()), - PolicyExecutionMode::KubewardenWapc, - )?; - - let policy_mapping = WAPC_POLICY_MAPPING.read().unwrap(); - let found = policy_mapping - .iter() - .find(|(_id, policy)| policy.id == policy_name); + // we need a real waPC policy, we don't care about the contents yet + let engine = wasmtime::Engine::default(); + let wat = include_bytes!("../test_data/endless_wasm/wapc_endless_loop.wat"); + let module = wasmtime::Module::new(&engine, wat).expect("cannot compile WAT to wasm"); - assert!(found.is_some()); + let epoch_deadlines = None; + let callback_channel = None; + let ctx_aware_resources_allow_list: BTreeSet = BTreeSet::new(); - Ok(()) - } + let runtime = create_wapc_runtime( + policy_id, + worker_id, + engine, + module, + epoch_deadlines, + callback_channel, + &ctx_aware_resources_allow_list, + ) + .expect("cannot create wapc runtime"); + let wapc_stack = match runtime { + Runtime::Wapc(stack) => stack, + _ => panic!("not the runtime I was expecting"), + }; - fn find_wapc_policy_id(policy: &Policy) -> Option { - let map = WAPC_POLICY_MAPPING - .read() - .expect("cannot get READ access to WAPC_POLICY_MAPPING"); - map.iter() - .find(|(_, v)| *v == 
policy) - .map(|(k, _)| k.to_owned()) - } + assert_eq!( + worker_id, + get_worker_id(wapc_stack.wapc_host_id()).expect("didn't find policy") + ); - fn is_wapc_instance_registered(policy_id: u64) -> bool { - let map = WAPC_POLICY_MAPPING - .read() - .expect("cannot get READ access to WAPC_POLICY_MAPPING"); - map.get(&policy_id).is_some() + // this will panic if the evaluation context has not been registered + let eval_ctx = get_eval_ctx(wapc_stack.wapc_host_id()); + assert_eq!(eval_ctx.policy_id, policy_id); } #[test] - fn policy_wapc_mapping_is_cleaned_when_the_evaluator_is_dropped() { - // we need a real WASM module, we don't care about the contents yet + fn wapc_policy_is_removed_from_registry_when_the_evaluator_is_dropped() { + // we need a real waPC policy, we don't care about the contents yet + let worker_id = 0; let engine = wasmtime::Engine::default(); let wat = include_bytes!("../test_data/endless_wasm/wapc_endless_loop.wat"); let module = wasmtime::Module::new(&engine, wat).expect("cannot compile WAT to wasm"); - let builder = PolicyEvaluatorBuilder::new("test".to_string()) + let builder = PolicyEvaluatorBuilder::new("test".to_string(), worker_id) .execution_mode(PolicyExecutionMode::KubewardenWapc) .engine(engine) .policy_module(module); let evaluator = builder.build().expect("cannot create evaluator"); - let policy_id = - find_wapc_policy_id(&evaluator.policy).expect("cannot find the wapc we just created"); + let wapc_stack = match evaluator.runtime() { + Runtime::Wapc(ref stack) => stack, + _ => panic!("not the runtime I was expecting"), + }; + let wapc_id = wapc_stack.wapc_host_id(); drop(evaluator); - assert!(!is_wapc_instance_registered(policy_id)); - } - - #[test] - fn policy_is_not_registered_in_the_mapping_if_not_wapc() -> Result<()> { - let policy_name = "policy_is_not_registered_in_the_mapping_if_not_wapc"; - - // We cannot set policy.id at build time, because some attributes - // of Policy are private. 
- let mut policy = Policy::default(); - policy.id = policy_name.to_string(); - - let policy_id = 1; - - PolicyEvaluatorBuilder::from_contents_internal( - policy_name.to_string(), - None, - None, - || Some(policy_id), - |_, _, _, _| Ok(policy.clone()), - PolicyExecutionMode::OpaGatekeeper, - )?; - - let policy_mapping = WAPC_POLICY_MAPPING.read().unwrap(); - let found = policy_mapping - .iter() - .find(|(_id, policy)| policy.id == policy_name); - - assert!(found.is_none()); - Ok(()) + assert!(!is_wapc_instance_registered(wapc_id)); } } diff --git a/src/policy_tracing.rs b/src/policy_tracing.rs index a8c719a5..554c7464 100644 --- a/src/policy_tracing.rs +++ b/src/policy_tracing.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use serde::{Deserialize, Serialize}; use tracing::{event, Level}; -use crate::policy::Policy; +use crate::evaluation_context::EvaluationContext; #[derive(Debug, Serialize)] enum PolicyLogEntryLevel { @@ -38,7 +38,7 @@ struct PolicyLogEntry { data: Option>, } -impl Policy { +impl EvaluationContext { #[tracing::instrument(name = "policy_log", skip(contents))] pub(crate) fn log(&self, contents: &[u8]) -> Result<()> { let log_entry: PolicyLogEntry = serde_json::from_slice(contents)?; diff --git a/src/runtimes/mod.rs b/src/runtimes/mod.rs index f420b6b2..bf86295f 100644 --- a/src/runtimes/mod.rs +++ b/src/runtimes/mod.rs @@ -1,3 +1,7 @@ +use std::fmt::Display; + +use crate::policy_evaluator::RegoPolicyExecutionMode; + pub(crate) mod rego; pub(crate) mod wapc; pub(crate) mod wasi_cli; @@ -7,3 +11,20 @@ pub(crate) enum Runtime { Burrego(rego::BurregoStack), Cli(wasi_cli::Stack), } + +impl Display for Runtime { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Runtime::Cli(_) => write!(f, "wasi"), + Runtime::Wapc(_) => write!(f, "wapc"), + Runtime::Burrego(stack) => match stack.policy_execution_mode { + RegoPolicyExecutionMode::Opa => { + write!(f, "OPA") + } + RegoPolicyExecutionMode::Gatekeeper => { + write!(f, 
"Gatekeeper") + } + }, + } + } +} diff --git a/src/runtimes/wapc/callback.rs b/src/runtimes/wapc/callback.rs index 3de4e9e5..febb18df 100644 --- a/src/runtimes/wapc/callback.rs +++ b/src/runtimes/wapc/callback.rs @@ -9,29 +9,22 @@ use tracing::{debug, error, warn}; use crate::callback_handler::verify_certificate; use crate::callback_requests::{CallbackRequest, CallbackRequestType, CallbackResponse}; -use crate::runtimes::wapc::mapping::get_policy; +use crate::runtimes::wapc::evaluation_context_registry::get_eval_ctx; pub(crate) fn host_callback( - policy_id: u64, + wapc_id: u64, binding: &str, namespace: &str, operation: &str, payload: &[u8], ) -> Result, Box> { + //TODO: the code could be DRY-ed using some macros match binding { "kubewarden" => match namespace { "tracing" => match operation { "log" => { - let policy = get_policy(policy_id).map_err(|e| { - error!( - ?policy_id, - ?binding, - ?namespace, - ?operation, - error = ?e, "Cannot find requested policy"); - e - })?; - if let Err(e) = policy.log(payload) { + let eval_ctx = get_eval_ctx(wapc_id); + if let Err(e) = eval_ctx.log(payload) { let p = String::from_utf8(payload.to_vec()).unwrap_or_else(|e| e.to_string()); error!( @@ -58,7 +51,7 @@ pub(crate) fn host_callback( response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } "v2/verify" => { let req: SigstoreVerificationInputV2 = @@ -70,12 +63,12 @@ pub(crate) fn host_callback( response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } "v1/manifest_digest" => { let image: String = serde_json::from_slice(payload.to_vec().as_ref())?; debug!( - policy_id, + wapc_id, binding, operation, image = image.as_str(), @@ -86,7 +79,7 @@ pub(crate) fn host_callback( request: CallbackRequestType::OciManifestDigest { image }, 
response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } _ => { error!("unknown operation: {}", operation); @@ -97,7 +90,7 @@ pub(crate) fn host_callback( "v1/dns_lookup_host" => { let host: String = serde_json::from_slice(payload.to_vec().as_ref())?; debug!( - policy_id, + wapc_id, binding, operation, ?host, @@ -108,7 +101,7 @@ pub(crate) fn host_callback( request: CallbackRequestType::DNSLookupHost { host }, response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } _ => { error!("unknown operation: {}", operation); @@ -134,24 +127,16 @@ pub(crate) fn host_callback( }, "kubernetes" => match operation { "list_resources_by_namespace" => { - let policy = get_policy(policy_id).map_err(|e| { - error!( - policy_id, - ?binding, - ?namespace, - ?operation, - error = ?e, "Cannot find requested policy"); - e - })?; + let eval_ctx = get_eval_ctx(wapc_id); let req: ListResourcesByNamespaceRequest = serde_json::from_slice(payload.to_vec().as_ref())?; - if !policy.can_access_kubernetes_resource(&req.api_version, &req.kind) { + if !eval_ctx.can_access_kubernetes_resource(&req.api_version, &req.kind) { error!( - policy = policy.id, + policy = eval_ctx.policy_id, resource_requested = format!("{}/{}", req.api_version, req.kind), - resources_allowed = ?policy.ctx_aware_resources_allow_list, + resources_allowed = ?eval_ctx.ctx_aware_resources_allow_list, "Policy tried to access a Kubernetes resource it doesn't have access to"); return Err(format!( "Policy has not been granted access to Kubernetes {}/{} resources. 
The violation has been reported.", @@ -160,7 +145,7 @@ pub(crate) fn host_callback( } debug!( - policy_id, + wapc_id, binding, operation, ?req, @@ -171,26 +156,18 @@ pub(crate) fn host_callback( request: CallbackRequestType::from(req), response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } "list_resources_all" => { - let policy = get_policy(policy_id).map_err(|e| { - error!( - policy_id, - ?binding, - ?namespace, - ?operation, - error = ?e, "Cannot find requested policy"); - e - })?; + let eval_ctx = get_eval_ctx(wapc_id); let req: ListAllResourcesRequest = serde_json::from_slice(payload.to_vec().as_ref())?; - if !policy.can_access_kubernetes_resource(&req.api_version, &req.kind) { + if !eval_ctx.can_access_kubernetes_resource(&req.api_version, &req.kind) { error!( - policy = policy.id, + policy = eval_ctx.policy_id, resource_requested = format!("{}/{}", req.api_version, req.kind), - resources_allowed = ?policy.ctx_aware_resources_allow_list, + resources_allowed = ?eval_ctx.ctx_aware_resources_allow_list, "Policy tried to access a Kubernetes resource it doesn't have access to"); return Err(format!( "Policy has not been granted access to Kubernetes {}/{} resources. 
The violation has been reported.", @@ -199,7 +176,7 @@ pub(crate) fn host_callback( } debug!( - policy_id, + wapc_id, binding, operation, ?req, @@ -210,26 +187,18 @@ pub(crate) fn host_callback( request: CallbackRequestType::from(req), response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } "get_resource" => { - let policy = get_policy(policy_id).map_err(|e| { - error!( - policy_id, - ?binding, - ?namespace, - ?operation, - error = ?e, "Cannot find requested policy"); - e - })?; + let eval_ctx = get_eval_ctx(wapc_id); let req: GetResourceRequest = serde_json::from_slice(payload.to_vec().as_ref())?; - if !policy.can_access_kubernetes_resource(&req.api_version, &req.kind) { + if !eval_ctx.can_access_kubernetes_resource(&req.api_version, &req.kind) { error!( - policy = policy.id, + policy = eval_ctx.policy_id, resource_requested = format!("{}/{}", req.api_version, req.kind), - resources_allowed = ?policy.ctx_aware_resources_allow_list, + resources_allowed = ?eval_ctx.ctx_aware_resources_allow_list, "Policy tried to access a Kubernetes resource it doesn't have access to"); return Err(format!( "Policy has not been granted access to Kubernetes {}/{} resources. 
The violation has been reported.", @@ -238,7 +207,7 @@ pub(crate) fn host_callback( } debug!( - policy_id, + wapc_id, binding, operation, ?req, @@ -249,7 +218,7 @@ pub(crate) fn host_callback( request: CallbackRequestType::from(req), response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } _ => { error!(namespace, operation, "unknown operation"); @@ -270,13 +239,13 @@ pub(crate) fn host_callback( field_selector: None, }; - warn!(policy_id, ?req, "Usage of deprecated `ClusterContext`"); + warn!(wapc_id, ?req, "Usage of deprecated `ClusterContext`"); let (tx, rx) = oneshot::channel::>(); let req = CallbackRequest { request: req, response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } "namespaces" => { let req = CallbackRequestType::KubernetesListResourceAll { @@ -286,13 +255,13 @@ pub(crate) fn host_callback( field_selector: None, }; - warn!(policy_id, ?req, "Usage of deprecated `ClusterContext`"); + warn!(wapc_id, ?req, "Usage of deprecated `ClusterContext`"); let (tx, rx) = oneshot::channel::>(); let req = CallbackRequest { request: req, response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, operation, req, rx) } "services" => { let req = CallbackRequestType::KubernetesListResourceAll { @@ -302,13 +271,13 @@ pub(crate) fn host_callback( field_selector: None, }; - warn!(policy_id, ?req, "Usage of deprecated `ClusterContext`"); + warn!(wapc_id, ?req, "Usage of deprecated `ClusterContext`"); let (tx, rx) = oneshot::channel::>(); let req = CallbackRequest { request: req, response_channel: tx, }; - send_request_and_wait_for_response(policy_id, binding, operation, req, rx) + send_request_and_wait_for_response(wapc_id, binding, 
operation, req, rx) } _ => { error!("unknown namespace: {}", namespace); @@ -329,20 +298,20 @@ fn send_request_and_wait_for_response( req: CallbackRequest, mut rx: Receiver>, ) -> Result, Box> { - let policy = get_policy(policy_id)?; + let eval_ctx = get_eval_ctx(policy_id); - let cb_channel: mpsc::Sender = if let Some(c) = policy.callback_channel.clone() - { - Ok(c) - } else { - error!( - policy_id, - binding, operation, "Cannot process waPC request: callback channel not provided" - ); - Err(anyhow!( - "Cannot process waPC request: callback channel not provided" - )) - }?; + let cb_channel: mpsc::Sender = + if let Some(c) = eval_ctx.callback_channel.clone() { + Ok(c) + } else { + error!( + policy_id, + binding, operation, "Cannot process waPC request: callback channel not provided" + ); + Err(anyhow!( + "Cannot process waPC request: callback channel not provided" + )) + }?; let send_result = cb_channel.try_send(req); if let Err(e) = send_result { diff --git a/src/runtimes/wapc/evaluation_context_registry.rs b/src/runtimes/wapc/evaluation_context_registry.rs new file mode 100644 index 00000000..6a6ae3cb --- /dev/null +++ b/src/runtimes/wapc/evaluation_context_registry.rs @@ -0,0 +1,329 @@ +/// This module provides helper functions and data structures used to keep track of the waPC +/// policies being instantiated. This is all modelled to optimize the Policy Server scenario, +/// although everything works fine also with kwctl. +/// +/// +/// ## The problem +/// +/// In the Policy Server scenario, we have multiple workers. Each one of them living inside of +/// their own dedicated thread. +/// +/// Given a unique list of waPC policies, a worker will instantiate only one wapc runtime per +/// policy. This is done to reduce the amount of memory consumed by Policy Server. 
+/// Later, when inside of the [`host_callback`](crate::runtimes::wapc::callback::host_callback) function,
+/// we receive a `wapc_id` and the details about an operation that has to be run by the Wasm
+/// host. When that happens, we need to access some auxiliary information about the policy being
+/// evaluated. For example, if a policy is requesting "list all the Kubernetes Secret objects
+/// inside of the `kube-system` Namespace", we need to know if the administrator granted access to
+/// Kubernetes Secret to the policy.
+/// The auxiliary information is stored inside of an `EvaluationContext` object.
+///
+/// ## The data structures
+///
+/// This module defines two global variables that hold all the information required by `host_callback` to
+/// obtain the `EvaluationContext` instance associated with a waPC policy. All of this is achieved
+/// through a series of lookups based on the `wapc_id` associated with the policy.
+///
+/// The `WAPC_ID_TO_WORKER_ID` structure contains the relationship waPC policy -> worker.
+/// Given a waPC policy ID, we can discover to which worker it belongs.
+/// Then, using the `WORKER_ID_TO_CTX` structure, we find the `EvaluationContext` associated with a
+/// certain worker.
+///
+/// ## Workflow
+///
+/// ### Policy registration
+///
+/// As soon as a waPC policy is created, the following information has to be inserted into the
+/// registry:
+/// - wapc ID
+/// - ID of the worker that owns the policy
+/// - The first `EvaluationContext` to be used
+///
+/// ### Validate request/settings
+///
+/// Prior to invoking the `validate`/`validate_settings` functions exposed by a policy, the worker
+/// must inform the registry about the `EvaluationContext` that is going to be used during the
+/// evaluation. This ensures the `host_callback`, if ever called by the policy, obtains the right
+/// auxiliary information. 
+use anyhow::{anyhow, Result}; +use lazy_static::lazy_static; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, RwLock, + }, +}; +use tracing::debug; + +use crate::evaluation_context::EvaluationContext; + +lazy_static! { + /// A map with wapc_id as key, and the worker_id as value. It allows us to know to which + /// worker a waPC policy belongs. When inside of the + /// [`host_callback`](crate::runtimes::wapc::callback::host_callback) function, + /// we need to know the `EvaluationContext` to be used + static ref WAPC_ID_TO_WORKER_ID: RwLock> = RwLock::new(HashMap::new()); + + /// A Map with worker_id as key, and the current `EvaluationContext` as value + static ref WORKER_ID_TO_CTX: RwLock>>> = + RwLock::new(HashMap::new()); + + /// A Map with worker_id as key, and a counter as value. This is used to keep track of how + /// many waPC policies are currently assigned to a worker. This is used to garbage collect + /// entries inside of the WORKER_ID_TO_CTX map. More details inside of the `unregister_policy` + /// function below + static ref WORKER_ID_TO_ACTIVE_POLICIES_COUNTER: RwLock>> = RwLock::new(HashMap::new()); +} + +/// Register a waPC policy inside of the global registry +pub(crate) fn register_policy( + wapc_id: u64, + worker_id: u64, + eval_ctx: Arc>, +) { + let mut map = WAPC_ID_TO_WORKER_ID.write().unwrap(); + map.insert(wapc_id, worker_id); + debug!( + wapc_id, + worker_id, "registered waPC policy inside of global registry" + ); + + let mut map = WORKER_ID_TO_CTX.write().unwrap(); + map.insert(worker_id, eval_ctx); + debug!(worker_id, "registered evaluation context"); + + let mut map = WORKER_ID_TO_ACTIVE_POLICIES_COUNTER.write().unwrap(); + map.entry(worker_id) + .and_modify(|counter| { + let _ = counter.fetch_add(1, Ordering::Relaxed); + }) + .or_insert(Arc::new(AtomicU64::new(1))); +} + +/// Set the evaluation context used by worker. 
To be invoked **before** starting a policy +/// `validate` or `validate_settings` operation +pub(crate) fn set_worker_ctx(worker_id: u64, evaluation_context: &EvaluationContext) { + let map = WORKER_ID_TO_CTX.read().unwrap(); + let mut ctx = map + .get(&worker_id) + .expect("cannot find worker") + .write() + .unwrap(); + ctx.copy_from(evaluation_context); +} + +/// Removes a policy from the registry. To be used only when the policy is no longer being used +pub(crate) fn unregister_policy(wapc_id: u64) { + let mut map = WAPC_ID_TO_WORKER_ID.write().unwrap(); + let worker_id = match map.remove(&wapc_id) { + Some(id) => id, + None => { + // Should not happen, the policy has already been dropped or wasn't known + return; + } + }; + + let mut map = WORKER_ID_TO_ACTIVE_POLICIES_COUNTER.write().unwrap(); + let counter = match map.get_mut(&worker_id) { + Some(counter) => counter, + None => { + // Should not happen, the worker_id has already been unregistered + return; + } + }; + + let active_policies_for_worker = counter.fetch_sub(1, Ordering::Relaxed); + if active_policies_for_worker == 0 { + // We can remove the worker entry from WORKER_ID_TO_CTX, which will release the + // reference made against the EvaluationContext, avoiding a memory leak of this struct + let _ = WORKER_ID_TO_CTX.write().unwrap().remove(&worker_id); + } +} + +/// Find which worker owns the given waPC policy +pub(crate) fn get_worker_id(wapc_id: u64) -> Result { + let mapping = WAPC_ID_TO_WORKER_ID.read().unwrap(); + + mapping + .get(&wapc_id) + .ok_or_else(|| anyhow!("cannot find policy with ID {}", wapc_id)) + .cloned() +} + +/// Given a waPC policy ID, find the evaluation context associated with it +pub(crate) fn get_eval_ctx(wapc_id: u64) -> EvaluationContext { + let worker_id = { + let map = WAPC_ID_TO_WORKER_ID.read().unwrap(); + + map.get(&wapc_id) + .expect("cannot find policy inside of WAPC_ID_TO_WORKER_ID") + .to_owned() + }; + + let map = WORKER_ID_TO_CTX.read().unwrap(); + let ctx = map + 
.get(&worker_id) + .expect("cannot find worker") + .read() + .unwrap(); + ctx.to_owned() +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + use std::collections::BTreeSet; + use test_context::{test_context, TestContext}; + + pub(crate) fn is_wapc_instance_registered(wapc_id: u64) -> bool { + let map = WAPC_ID_TO_WORKER_ID.read().unwrap(); + map.contains_key(&wapc_id) + } + + struct TestCtx { + evaluation_context: Arc>, + } + + impl TestContext for TestCtx { + fn setup() -> TestCtx { + let evaluation_context = Arc::new(RwLock::new(EvaluationContext { + policy_id: "test".to_string(), + callback_channel: None, + ctx_aware_resources_allow_list: BTreeSet::new(), + })); + + TestCtx { evaluation_context } + } + + fn teardown(self) { + // wipe all the registries + WAPC_ID_TO_WORKER_ID.write().unwrap().clear(); + WORKER_ID_TO_CTX.write().unwrap().clear(); + WORKER_ID_TO_ACTIVE_POLICIES_COUNTER + .write() + .unwrap() + .clear(); + } + } + + fn verify_registry_contents( + wapc_id: u64, + expected_worker_id: u64, + expected_worker_policy_counter: u64, + expected_eval_ctx_policy_id: &str, + ) { + assert_eq!(expected_worker_id, get_worker_id(wapc_id).unwrap()); + + let actual_ctx = get_eval_ctx(wapc_id); + assert_eq!( + expected_eval_ctx_policy_id.to_string(), + actual_ctx.policy_id + ); + + let map = WORKER_ID_TO_ACTIVE_POLICIES_COUNTER.read().unwrap(); + let policy_counter = map.get(&expected_worker_id).unwrap(); + assert_eq!( + expected_worker_policy_counter, + policy_counter.load(Ordering::Relaxed) + ); + } + + #[test_context(TestCtx)] + fn register_policy_initializes_internal_structures(test_ctx: &mut TestCtx) { + let wapc_id = 1; + let worker_a_id = 100; + let expected_policy_id = test_ctx + .evaluation_context + .clone() + .read() + .unwrap() + .policy_id + .clone(); + register_policy(wapc_id, worker_a_id, test_ctx.evaluation_context.clone()); + verify_registry_contents(wapc_id, worker_a_id, 1, &expected_policy_id); + + // register another policy against the 
same worker + let wapc_id = 2; + register_policy(wapc_id, worker_a_id, test_ctx.evaluation_context.clone()); + verify_registry_contents(wapc_id, 100, 2, &expected_policy_id); + + // register another policy against a new worker + let new_wapc_id = 3; + let worker_b_id = 200; + register_policy( + new_wapc_id, + worker_b_id, + test_ctx.evaluation_context.clone(), + ); + verify_registry_contents(new_wapc_id, worker_b_id, 1, &expected_policy_id); + verify_registry_contents(wapc_id, worker_a_id, 2, &expected_policy_id); + } + + #[test_context(TestCtx)] + fn change_worker_context(test_ctx: &mut TestCtx) { + let wapc_id = 1; + let worker_id = 100; + let expected_policy_id = test_ctx + .evaluation_context + .clone() + .read() + .unwrap() + .policy_id + .clone(); + register_policy(wapc_id, worker_id, test_ctx.evaluation_context.clone()); + verify_registry_contents(wapc_id, worker_id, 1, &expected_policy_id); + + let new_policy_id = "a new one".to_string(); + let mut new_evaluation_context: EvaluationContext = { + // the fixture returns a RWLock, we just need a plain EvaluationContext that we can + // change + test_ctx.evaluation_context.clone().write().unwrap().clone() + }; + + new_evaluation_context.policy_id = new_policy_id.clone(); + set_worker_ctx(worker_id, &new_evaluation_context); + verify_registry_contents(wapc_id, worker_id, 1, &new_policy_id); + } + + #[test_context(TestCtx)] + fn test_unregister_policies(test_ctx: &mut TestCtx) { + let worker_id = 100; + let num_of_policies = 10; + for wapc_id in 0..num_of_policies { + register_policy(wapc_id, worker_id, test_ctx.evaluation_context.clone()); + } + + { + // ensure the read lock goes out of scope + let map = WORKER_ID_TO_ACTIVE_POLICIES_COUNTER.read().unwrap(); + let counter = map.get(&worker_id).unwrap().load(Ordering::Relaxed); + assert_eq!(counter, num_of_policies); + } + + // start dropping the policies, one by one + let mut expected_number_of_policies = num_of_policies; + for wapc_id in 0..num_of_policies { + 
unregister_policy(wapc_id); + expected_number_of_policies -= 1; + + { + // ensure the read lock goes out of scope + let map = WORKER_ID_TO_ACTIVE_POLICIES_COUNTER.read().unwrap(); + let counter = map.get(&worker_id).unwrap().load(Ordering::Relaxed); + assert_eq!(counter, expected_number_of_policies); + } + } + + // the worker should have 0 policies associated + let map = WORKER_ID_TO_ACTIVE_POLICIES_COUNTER.read().unwrap(); + let counter = map.get(&worker_id).unwrap().load(Ordering::Relaxed); + assert_eq!(counter, 0); + + // the structure holding the worker_id => ctx should not have any reference to + // the worker id + let map = WORKER_ID_TO_CTX.read().unwrap(); + assert!(map.get(&worker_id).is_none()); + } +} diff --git a/src/runtimes/wapc/mapping.rs b/src/runtimes/wapc/mapping.rs deleted file mode 100644 index d7a26abf..00000000 --- a/src/runtimes/wapc/mapping.rs +++ /dev/null @@ -1,23 +0,0 @@ -use anyhow::{anyhow, Result}; -use lazy_static::lazy_static; -use std::{collections::HashMap, sync::RwLock}; - -use crate::policy::Policy; - -lazy_static! 
{ - pub(crate) static ref WAPC_POLICY_MAPPING: RwLock> = - RwLock::new(HashMap::with_capacity(64)); -} - -pub(crate) fn get_policy(policy_id: u64) -> Result { - let policy_mapping = WAPC_POLICY_MAPPING.read().map_err(|e| { - anyhow!( - "Cannot obtain read lock access to WAPC_POLICY_MAPPING: {}", - e - ) - })?; - policy_mapping - .get(&policy_id) - .ok_or_else(|| anyhow!("Cannot find policy with ID {}", policy_id)) - .cloned() -} diff --git a/src/runtimes/wapc/mod.rs b/src/runtimes/wapc/mod.rs index 39d7253d..234d1310 100644 --- a/src/runtimes/wapc/mod.rs +++ b/src/runtimes/wapc/mod.rs @@ -1,10 +1,7 @@ mod callback; - +pub(crate) mod evaluation_context_registry; +mod runtime; mod stack; -pub(crate) use stack::WapcStack; -mod runtime; pub(crate) use runtime::Runtime; - -mod mapping; -pub(crate) use mapping::WAPC_POLICY_MAPPING; +pub(crate) use stack::WapcStack; diff --git a/src/runtimes/wapc/stack.rs b/src/runtimes/wapc/stack.rs index c8167dff..833f6a2a 100644 --- a/src/runtimes/wapc/stack.rs +++ b/src/runtimes/wapc/stack.rs @@ -1,8 +1,12 @@ -use anyhow::{anyhow, Result}; -use tracing::warn; +use anyhow::Result; +use std::sync::{Arc, RwLock}; use wasmtime_provider::wasmtime; -use crate::runtimes::wapc::{callback::host_callback, WAPC_POLICY_MAPPING}; +use crate::runtimes::wapc::{ + callback::host_callback, evaluation_context_registry::unregister_policy, +}; + +use super::evaluation_context_registry::{get_eval_ctx, get_worker_id, register_policy}; pub(crate) struct WapcStack { engine: wasmtime::Engine, @@ -41,18 +45,15 @@ impl WapcStack { self.epoch_deadlines, )?; let old_wapc_host_id = self.wapc_host.id(); + let worker_id = get_worker_id(old_wapc_host_id)?; - // Remove the old policy from WAPC_POLICY_MAPPING and add the new one - // We need a write lock to do that - { - let mut map = WAPC_POLICY_MAPPING - .write() - .expect("cannot get write access to WAPC_POLICY_MAPPING"); - let policy = map.remove(&old_wapc_host_id).ok_or_else(|| { - anyhow!("cannot find old waPC 
policy with id {}", old_wapc_host_id) - })?; - map.insert(new_wapc_host.id(), policy); - } + let eval_ctx = get_eval_ctx(old_wapc_host_id); + unregister_policy(old_wapc_host_id); + register_policy( + new_wapc_host.id(), + worker_id, + Arc::new(RwLock::new(eval_ctx)), + ); self.wapc_host = new_wapc_host; @@ -94,13 +95,6 @@ impl WapcStack { impl Drop for WapcStack { fn drop(&mut self) { // ensure we clean this entry from the WAPC_POLICY_MAPPING mapping - match WAPC_POLICY_MAPPING.write() { - Ok(mut map) => { - map.remove(&self.wapc_host.id()); - } - Err(_) => { - warn!("cannot cleanup policy from WAPC_POLICY_MAPPING"); - } - } + unregister_policy(self.wapc_host.id()); } }