diff --git a/src/auth_action.rs b/src/auth_action.rs index 94af576c..5b1dbd9a 100644 --- a/src/auth_action.rs +++ b/src/auth_action.rs @@ -1,7 +1,8 @@ use crate::configuration::{Action, FailureMode, Service}; -use crate::data::Predicate; -use crate::service::GrpcService; -use log::error; +use crate::data::{store_metadata, Predicate, PredicateVec}; +use crate::envoy::{CheckResponse, CheckResponse_oneof_http_response, HeaderValueOption}; +use crate::service::{GrpcErrResponse, GrpcService, Headers}; +use log::debug; use std::rc::Rc; #[derive(Debug)] @@ -34,28 +35,89 @@ impl AuthAction { } pub fn conditions_apply(&self) -> bool { - let predicates = &self.predicates; - predicates.is_empty() - || predicates.iter().all(|predicate| match predicate.test() { - Ok(b) => b, - Err(err) => { - error!("Failed to evaluate {:?}: {}", predicate, err); - panic!("Err out of this!") - } - }) + self.predicates.apply() } pub fn get_failure_mode(&self) -> FailureMode { self.grpc_service.get_failure_mode() } + + pub fn process_response( + &self, + check_response: CheckResponse, + ) -> Result { + //todo(adam-cattermole):hostvar resolver? + // store dynamic metadata in filter state + debug!("process_response(auth): store_metadata"); + store_metadata(check_response.get_dynamic_metadata()); + + match check_response.http_response { + None => { + debug!("process_response(auth): received no http_response"); + match self.get_failure_mode() { + FailureMode::Deny => Err(GrpcErrResponse::new_internal_server_error()), + FailureMode::Allow => { + debug!("process_response(auth): continuing as FailureMode Allow"); + Ok(Vec::default()) + } + } + } + Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { + debug!("process_response(auth): received DeniedHttpResponse"); + let status_code = denied_response.get_status().get_code(); + let response_headers = Self::get_header_vec(denied_response.get_headers()); + Err(GrpcErrResponse::new( + status_code as u32, + response_headers, + denied_response.body, + )) + } + Some(CheckResponse_oneof_http_response::ok_response(ok_response)) => { + debug!("process_response(auth): received OkHttpResponse"); + + if !ok_response.get_response_headers_to_add().is_empty() { + panic!("process_response(auth): response contained response_headers_to_add which is unsupported!") + } + if !ok_response.get_headers_to_remove().is_empty() { + panic!("process_response(auth): response contained headers_to_remove which is unsupported!") + } + if !ok_response.get_query_parameters_to_set().is_empty() { + panic!("process_response(auth): response contained query_parameters_to_set which is unsupported!") + } + if !ok_response.get_query_parameters_to_remove().is_empty() { + panic!("process_response(auth): response contained query_parameters_to_remove which is unsupported!") + } + Ok(Self::get_header_vec(ok_response.get_headers())) + } + } + } + + fn get_header_vec(headers: &[HeaderValueOption]) -> Headers { + headers + .iter() + .map(|header| { + let hv = header.get_header(); + (hv.key.to_owned(), hv.value.to_owned()) + }) + .collect() + } } #[cfg(test)] mod test { use super::*; use crate::configuration::{Action, FailureMode, Service, ServiceType, Timeout}; + use crate::envoy::{DeniedHttpResponse, HeaderValue, HttpStatus, OkHttpResponse, StatusCode}; + use protobuf::RepeatedField; fn build_auth_action_with_predicates(predicates: Vec) -> AuthAction { + build_auth_action_with_predicates_and_failure_mode(predicates, FailureMode::default()) + } + + fn build_auth_action_with_predicates_and_failure_mode( + predicates: Vec, + failure_mode: FailureMode, + ) -> AuthAction { let action = Action { service: "some_service".into(), scope: "some_scope".into(), @@ -66,7 +128,7 @@ mod test { let service = Service { service_type: ServiceType::Auth, endpoint: "some_endpoint".into(), - failure_mode: FailureMode::default(), + failure_mode, timeout: Timeout::default(), }; @@ -74,6 +136,56 @@ mod test { .expect("action building failed. Maybe predicates compilation?") } + fn build_check_response( + status: StatusCode, + headers: Option>, + body: Option, + ) -> CheckResponse { + let mut response = CheckResponse::new(); + match status { + StatusCode::OK => { + let mut ok_http_response = OkHttpResponse::new(); + if let Some(header_list) = headers { + ok_http_response.set_headers(build_headers(header_list)) + } + response.set_ok_response(ok_http_response); + } + StatusCode::Forbidden => { + let mut http_status = HttpStatus::new(); + http_status.set_code(status); + + let mut denied_http_response = DeniedHttpResponse::new(); + denied_http_response.set_status(http_status); + if let Some(header_list) = headers { + denied_http_response.set_headers(build_headers(header_list)); + } + denied_http_response.set_body(body.unwrap_or_default()); + response.set_denied_response(denied_http_response); + } + _ => { + // assume any other code is for error state + } + }; + response + } + + fn build_headers(headers: Vec<(&str, &str)>) -> RepeatedField { + headers + .into_iter() + .map(|(key, value)| { + let header_value = { + let mut hv = HeaderValue::new(); + hv.set_key(key.to_string()); + hv.set_value(value.to_string()); + hv + }; + let mut header_option = HeaderValueOption::new(); + header_option.set_header(header_value); + header_option + }) + .collect::>() + } + #[test] fn empty_predicates_do_apply() { let auth_action = build_auth_action_with_predicates(Vec::default()); @@ -108,4 +220,99 @@ mod test { ]); auth_action.conditions_apply(); } + + #[test] + fn process_ok_response() { + let auth_action = build_auth_action_with_predicates(Vec::default()); + let ok_response_without_headers = build_check_response(StatusCode::OK, None, None); + let result = auth_action.process_response(ok_response_without_headers); + assert!(result.is_ok()); + + let headers = result.expect("is ok"); + assert!(headers.is_empty()); + + let ok_response_with_header = + build_check_response(StatusCode::OK, Some(vec![("my_header", "my_value")]), None); + let result = auth_action.process_response(ok_response_with_header); + assert!(result.is_ok()); + + let headers = result.expect("is ok"); + assert!(!headers.is_empty()); + + assert_eq!( + headers[0], + ("my_header".to_string(), "my_value".to_string()) + ); + } + + #[test] + fn process_denied_response() { + let headers = vec![ + ("www-authenticate", "APIKEY realm=\"api-key-users\""), + ("x-ext-auth-reason", "credential not found"), + ]; + let auth_action = build_auth_action_with_predicates(Vec::default()); + let denied_response_empty = build_check_response(StatusCode::Forbidden, None, None); + let result = auth_action.process_response(denied_response_empty); + assert!(result.is_err()); + + let grpc_err_response = result.expect_err("is err"); + assert_eq!( + grpc_err_response.status_code(), + StatusCode::Forbidden as u32 + ); + assert!(grpc_err_response.headers().is_empty()); + assert_eq!(grpc_err_response.body(), String::default()); + + let denied_response_content = build_check_response( + StatusCode::Forbidden, + Some(headers.clone()), + Some("my_body".to_string()), + ); + let result = auth_action.process_response(denied_response_content); + assert!(result.is_err()); + + let grpc_err_response = result.expect_err("is err"); + assert_eq!( + grpc_err_response.status_code(), + StatusCode::Forbidden as u32 + ); + + let response_headers = grpc_err_response.headers(); + headers.iter().zip(response_headers.iter()).for_each( + |((header_one, value_one), (header_two, value_two))| { + assert_eq!(header_one, header_two); + assert_eq!(value_one, value_two); + }, + ); + + assert_eq!(grpc_err_response.body(), "my_body"); + } + + #[test] + fn process_error_response() { + let auth_action = + build_auth_action_with_predicates_and_failure_mode(Vec::default(), FailureMode::Deny); + let error_response = build_check_response(StatusCode::InternalServerError, None, None); + let result = auth_action.process_response(error_response); + assert!(result.is_err()); + + let grpc_err_response = result.expect_err("is err"); + assert_eq!( + grpc_err_response.status_code(), + StatusCode::InternalServerError as u32 + ); + + assert!(grpc_err_response.headers().is_empty()); + assert_eq!(grpc_err_response.body(), "Internal Server Error.\n"); + + let auth_action = + build_auth_action_with_predicates_and_failure_mode(Vec::default(), FailureMode::Allow); + let error_response = build_check_response(StatusCode::InternalServerError, None, None); + let result = auth_action.process_response(error_response); + assert!(result.is_ok()); + + let headers = result.expect("is ok"); + assert!(headers.is_empty()); + } } diff --git a/src/data/attribute.rs b/src/data/attribute.rs index a3a1d230..82bb0512 100644 --- a/src/data/attribute.rs +++ b/src/data/attribute.rs @@ -2,7 +2,6 @@ use crate::data::PropertyPath; use chrono::{DateTime, FixedOffset}; use log::{debug, error, warn}; use protobuf::well_known_types::Struct; -use proxy_wasm::hostcalls; use serde_json::Value; pub const KUADRANT_NAMESPACE: &str = "kuadrant"; @@ -120,7 +119,7 @@ where } pub fn set_attribute(attr: &str, value: &[u8]) { - match hostcalls::set_property(PropertyPath::from(attr).tokens(), Some(value)) { + match crate::data::property::set_property(PropertyPath::from(attr), Some(value)) { Ok(_) => (), Err(_) => error!("set_attribute: failed to set property {attr}"), }; diff --git a/src/data/cel.rs b/src/data/cel.rs index 6ef63020..66d80f25 100644 --- a/src/data/cel.rs +++ b/src/data/cel.rs @@ -7,7 +7,7 @@ use cel_parser::{parse, Expression as CelExpression, Member, ParseError}; use chrono::{DateTime, FixedOffset}; #[cfg(feature = "debug-host-behaviour")] use log::debug; -use log::warn; +use log::{error, warn}; use proxy_wasm::types::{Bytes, Status}; use serde_json::Value as JsonValue; use std::borrow::Cow; @@ -235,6 +235,23 @@ impl Predicate { } } +pub trait PredicateVec { + fn apply(&self) -> bool; +} + +impl PredicateVec for Vec { + fn apply(&self) -> bool { + self.is_empty() + || self.iter().all(|predicate| match predicate.test() { + Ok(b) => b, + Err(err) => { + error!("Failed to evaluate {:?}: {}", predicate, err); + panic!("Err out of this!") + } + }) + } +} + pub struct Attribute { path: Path, cel_type: Option, diff --git a/src/data/mod.rs b/src/data/mod.rs index 9b7dd28b..b780fe75 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -10,5 +10,6 @@ pub use cel::debug_all_well_known_attributes; pub use cel::Expression; pub use cel::Predicate; +pub use cel::PredicateVec; pub use property::Path as PropertyPath; diff --git a/src/data/property.rs b/src/data/property.rs index bcfdc278..17321c52 100644 --- a/src/data/property.rs +++ b/src/data/property.rs @@ -55,6 +55,14 @@ pub fn host_get_map(path: &Path) -> Result, String> { } } +#[cfg(test)] +pub fn host_set_property(path: Path, value: Option<&[u8]>) -> Result<(), Status> { + debug!("set_property: {:?}", path); + let data = value.map(|bytes| bytes.to_vec()).unwrap_or_default(); + test::TEST_PROPERTY_VALUE.set(Some((path, data))); + Ok(()) +} + #[cfg(not(test))] pub fn host_get_map(path: &Path) -> Result, String> { match *path.tokens() { @@ -77,6 +85,12 @@ pub(super) fn host_get_property(path: &Path) -> Result>, Status> proxy_wasm::hostcalls::get_property(path.tokens()) } +#[cfg(not(test))] +pub(super) fn host_set_property(path: Path, value: Option<&[u8]>) -> Result<(), Status> { + debug!("set_property: {:?}", path); + proxy_wasm::hostcalls::set_property(path.tokens(), value) +} + pub(super) fn get_property(path: &Path) -> Result>, Status> { match *path.tokens() { ["source", "remote_address"] => remote_address(), @@ -85,6 +99,10 @@ pub(super) fn get_property(path: &Path) -> Result>, Status> { } } +pub(super) fn set_property(path: Path, value: Option<&[u8]>) -> Result<(), Status> { + host_set_property(path, value) +} + #[derive(Clone, Hash, PartialEq, Eq)] pub struct Path { tokens: Vec, diff --git a/src/envoy/mod.rs b/src/envoy/mod.rs index ddbcf0e7..8dd169e7 100644 --- a/src/envoy/mod.rs +++ b/src/envoy/mod.rs @@ -36,7 +36,7 @@ pub use { AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, AttributeContext_Request, }, - base::Metadata, + base::{HeaderValue, HeaderValueOption, Metadata}, external_auth::{CheckRequest, CheckResponse, CheckResponse_oneof_http_response}, http_status::StatusCode, ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry}, @@ -44,4 +44,7 @@ pub use { }; #[cfg(test)] -pub use base::HeaderValue; +pub use { + external_auth::{DeniedHttpResponse, OkHttpResponse}, + http_status::HttpStatus, +}; diff --git a/src/filter.rs b/src/filter.rs deleted file mode 100644 index ed9c2418..00000000 --- a/src/filter.rs +++ /dev/null @@ -1,39 +0,0 @@ -pub(crate) mod http_context; -mod root_context; - -#[cfg_attr( - all( - target_arch = "wasm32", - target_vendor = "unknown", - target_os = "unknown" - ), - export_name = "_start" -)] -#[cfg_attr( - not(all( - target_arch = "wasm32", - target_vendor = "unknown", - target_os = "unknown" - )), - allow(dead_code) -)] -// This is a C interface, so make it explicit in the fn signature (and avoid mangling) -extern "C" fn start() { - use log::info; - use proxy_wasm::traits::RootContext; - use proxy_wasm::types::LogLevel; - use root_context::FilterRoot; - - proxy_wasm::set_log_level(LogLevel::Trace); - std::panic::set_hook(Box::new(|panic_info| { - proxy_wasm::hostcalls::log(LogLevel::Critical, &panic_info.to_string()) - .expect("failed to log panic_info"); - })); - proxy_wasm::set_root_context(|context_id| -> Box { - info!("#{} set_root_context", context_id); - Box::new(FilterRoot { - context_id, - config: Default::default(), - }) - }); -} diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs deleted file mode 100644 index 787df8e3..00000000 --- a/src/filter/http_context.rs +++ /dev/null @@ -1,176 +0,0 @@ -use crate::configuration::FailureMode; -#[cfg(feature = "debug-host-behaviour")] -use crate::data; -use crate::operation_dispatcher::{OperationDispatcher, OperationError}; -use crate::runtime_action_set::RuntimeActionSet; -use crate::runtime_config::RuntimeConfig; -use crate::service::GrpcService; -use log::{debug, warn}; -use proxy_wasm::traits::{Context, HttpContext}; -use proxy_wasm::types::Action; -use std::cell::RefCell; -use std::rc::Rc; - -pub struct Filter { - pub context_id: u32, - pub config: Rc, - pub response_headers_to_add: Vec<(String, String)>, - pub operation_dispatcher: RefCell, -} - -impl Filter { - fn request_authority(&self) -> String { - match self.get_http_request_header(":authority") { - None => { - warn!(":authority header not found"); - String::new() - } - Some(host) => { - let split_host = host.split(':').collect::>(); - split_host[0].to_owned() - } - } - } - - #[allow(unknown_lints, clippy::manual_inspect)] - fn process_action_sets(&self, m_set_list: &[Rc]) -> Action { - if let Some(m_set) = m_set_list.iter().find(|m_set| m_set.conditions_apply()) { - debug!("#{} action_set selected {}", self.context_id, m_set.name); - //debug!("#{} runtime action_set {:#?}", self.context_id, m_set); - self.operation_dispatcher - .borrow_mut() - .build_operations(&m_set.runtime_actions) - } else { - debug!( - "#{} process_action_sets: no action_set with conditions applies", - self.context_id - ); - return Action::Continue; - } - - match self.operation_dispatcher.borrow_mut().next() { - Ok(Some(op)) => match op.get_result() { - Ok(call_id) => { - debug!("#{} initiated gRPC call (id# {})", self.context_id, call_id); - Action::Pause - } - Err(e) => { - warn!("gRPC call failed! {e:?}"); - if let FailureMode::Deny = op.get_failure_mode() { - self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) - } - Action::Continue - } - }, - Ok(None) => { - Action::Continue // No operations left to perform - } - Err(OperationError { - failure_mode: FailureMode::Deny, - status, - }) => { - warn!("OperationError Status: {status:?}"); - self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")); - Action::Continue - } - Err(OperationError { - failure_mode: FailureMode::Allow, - status, - }) => { - warn!("OperationError Status: {status:?}"); - Action::Continue - } - } - } - - fn process_next_op(&self) { - match self.operation_dispatcher.borrow_mut().next() { - Ok(some_op) => { - if some_op.is_none() { - // No more operations left in queue, resuming - self.resume_http_request(); - } - } - Err(op_err) => { - // If desired, we could check the error status. - GrpcService::handle_error_on_grpc_response(op_err.failure_mode); - } - } - } -} - -impl HttpContext for Filter { - fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { - debug!("#{} on_http_request_headers", self.context_id); - - #[cfg(feature = "debug-host-behaviour")] - data::debug_all_well_known_attributes(); - - match self - .config - .index - .get_longest_match_action_sets(self.request_authority().as_str()) - { - None => { - debug!( - "#{} allowing request to pass because zero descriptors generated", - self.context_id - ); - Action::Continue - } - Some(m_sets) => self.process_action_sets(m_sets), - } - } - - fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - debug!("#{} on_http_response_headers", self.context_id); - for (name, value) in &self.response_headers_to_add { - self.add_http_response_header(name, value); - } - Action::Continue - } - - fn on_log(&mut self) { - debug!("#{} completed.", self.context_id); - } -} - -impl Context for Filter { - fn on_grpc_call_response(&mut self, token_id: u32, status_code: u32, resp_size: usize) { - debug!( - "#{} on_grpc_call_response: received gRPC call response: token: {token_id}, status: {status_code}", - self.context_id - ); - - let op_res = self - .operation_dispatcher - .borrow() - .get_waiting_operation(token_id); - - match op_res { - Ok(operation) => { - match GrpcService::process_grpc_response(Rc::clone(&operation), resp_size) { - Ok(result) => { - // add the response headers - self.response_headers_to_add.extend(result.response_headers); - // call the next op - self.process_next_op(); - } - Err(_) => { - match operation.get_failure_mode() { - FailureMode::Deny => {} - FailureMode::Allow => { - // call the next op - self.process_next_op(); - } - } - } - } - } - Err(e) => { - warn!("No Operation found with token_id: {token_id}"); - GrpcService::handle_error_on_grpc_response(e.failure_mode); - } - } - } -} diff --git a/src/filter/kuadrant_filter.rs b/src/filter/kuadrant_filter.rs new file mode 100644 index 00000000..14ec6ba1 --- /dev/null +++ b/src/filter/kuadrant_filter.rs @@ -0,0 +1,187 @@ +use crate::action_set_index::ActionSetIndex; +use crate::filter::operations::{ + GrpcMessageReceiverOperation, GrpcMessageSenderOperation, HeadersOperation, Operation, +}; +use crate::runtime_action_set::RuntimeActionSet; +use crate::service::{GrpcErrResponse, GrpcRequest, HeaderResolver}; +use log::{debug, warn}; +use proxy_wasm::traits::{Context, HttpContext}; +use proxy_wasm::types::{Action, Status}; +use std::mem; +use std::rc::Rc; + +pub(crate) struct KuadrantFilter { + context_id: u32, + index: Rc, + header_resolver: Rc, + + grpc_message_receiver_operation: Option, + headers_operations: Vec, +} + +impl Context for KuadrantFilter { + fn on_grpc_call_response(&mut self, token_id: u32, status_code: u32, resp_size: usize) { + debug!( + "#{} on_grpc_call_response: received gRPC call response: token: {token_id}, status: {status_code}", + self.context_id + ); + let receiver = mem::take(&mut self.grpc_message_receiver_operation) + .expect("We need an operation pending a gRPC response"); + + let mut ops = Vec::new(); + + if status_code != Status::Ok as u32 { + ops.push(receiver.fail()); + } else if let Some(response_body) = self.get_grpc_call_response_body(0, resp_size) { + ops.extend(receiver.digest_grpc_response(&response_body)); + } else { + ops.push(receiver.fail()); + } + + ops.into_iter().for_each(|op| { + self.handle_operation(op); + }) + } +} + +impl HttpContext for KuadrantFilter { + fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { + debug!("#{} on_http_request_headers", self.context_id); + + #[cfg(feature = "debug-host-behaviour")] + crate::data::debug_all_well_known_attributes(); + + if let Some(action_sets) = self + .index + .get_longest_match_action_sets(self.request_authority().as_ref()) + { + if let Some(action_set) = action_sets + .iter() + .find(|action_set| action_set.conditions_apply(/* self */)) + { + debug!( + "#{} action_set selected {}", + self.context_id, action_set.name + ); + return self.start_flow(Rc::clone(action_set)); + } + } + Action::Continue + } + + fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + debug!("#{} on_http_response_headers", self.context_id); + let headers_operations = mem::take(&mut self.headers_operations); + for op in headers_operations { + for (header, value) in &op.headers() { + self.add_http_response_header(header, value) + } + } + Action::Continue + } +} + +impl KuadrantFilter { + fn start_flow(&mut self, action_set: Rc) -> Action { + let grpc_request = action_set.find_first_grpc_request(); + let op = match grpc_request { + None => Operation::Done(), + Some(indexed_req) => { + Operation::SendGrpcRequest(GrpcMessageSenderOperation::new(action_set, indexed_req)) + } + }; + self.handle_operation(op) + } + + fn handle_operation(&mut self, operation: Operation) -> Action { + match operation { + Operation::SendGrpcRequest(sender_op) => { + debug!("handle_operation: SendGrpcRequest"); + let next_op = { + let (req, receiver_op) = sender_op.build_receiver_operation(); + match self.send_grpc_request(req) { + Ok(_token) => Operation::AwaitGrpcResponse(receiver_op), + Err(status) => { + debug!("handle_operation: failed to send grpc request `{status:?}`"); + receiver_op.fail() + } + } + }; + self.handle_operation(next_op) + } + Operation::AwaitGrpcResponse(receiver_op) => { + debug!("handle_operation: AwaitGrpcResponse"); + self.grpc_message_receiver_operation = Some(receiver_op); + Action::Pause + } + Operation::AddHeaders(header_op) => { + debug!("handle_operation: AddHeaders"); + self.headers_operations.push(header_op); + Action::Continue + } + Operation::Die(die_op) => { + debug!("handle_operation: Die"); + self.die(die_op); + Action::Continue + } + Operation::Done() => { + debug!("handle_operation: Done"); + self.resume_http_request(); + Action::Continue + } + } + } + + fn die(&mut self, die: GrpcErrResponse) { + self.send_http_response( + die.status_code(), + die.headers(), + Some(die.body().as_bytes()), + ); + } + + fn request_authority(&self) -> String { + match self.get_http_request_header(":authority") { + None => { + warn!(":authority header not found"); + String::new() + } + Some(host) => { + let split_host = host.split(':').collect::>(); + split_host[0].to_owned() + } + } + } + + fn send_grpc_request(&self, req: GrpcRequest) -> Result { + let headers = self + .header_resolver + .get_with_ctx(self) + .iter() + .map(|(header, value)| (*header, value.as_slice())) + .collect(); + + self.dispatch_grpc_call( + req.upstream_name(), + req.service_name(), + req.method_name(), + headers, + req.message(), + req.timeout(), + ) + } + + pub fn new( + context_id: u32, + index: Rc, + header_resolver: Rc, + ) -> Self { + Self { + context_id, + index, + header_resolver, + grpc_message_receiver_operation: None, + headers_operations: Vec::default(), + } + } +} diff --git a/src/filter/mod.rs b/src/filter/mod.rs new file mode 100644 index 00000000..089a1048 --- /dev/null +++ b/src/filter/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod kuadrant_filter; +pub(crate) mod operations; +pub(crate) mod root_context; diff --git a/src/filter/operations.rs b/src/filter/operations.rs new file mode 100644 index 00000000..2d6d0ad0 --- /dev/null +++ b/src/filter/operations.rs @@ -0,0 +1,107 @@ +use crate::configuration::FailureMode; +use crate::filter::operations::Operation::SendGrpcRequest; +use crate::runtime_action_set::RuntimeActionSet; +use crate::service::{GrpcErrResponse, GrpcRequest, Headers, IndexedGrpcRequest}; +use std::rc::Rc; + +pub enum Operation { + SendGrpcRequest(GrpcMessageSenderOperation), + AwaitGrpcResponse(GrpcMessageReceiverOperation), + AddHeaders(HeadersOperation), + Die(GrpcErrResponse), + // Done indicates that we have no more operations and can resume the http request flow + Done(), +} + +pub struct GrpcMessageSenderOperation { + runtime_action_set: Rc, + grpc_request: IndexedGrpcRequest, +} + +impl GrpcMessageSenderOperation { + pub fn new( + runtime_action_set: Rc, + indexed_request: IndexedGrpcRequest, + ) -> Self { + Self { + runtime_action_set, + grpc_request: indexed_request, + } + } + + pub fn build_receiver_operation(self) -> (GrpcRequest, GrpcMessageReceiverOperation) { + let index = self.grpc_request.index(); + ( + self.grpc_request.request(), + GrpcMessageReceiverOperation::new(self.runtime_action_set, index), + ) + } +} + +pub struct GrpcMessageReceiverOperation { + runtime_action_set: Rc, + current_index: usize, +} + +impl GrpcMessageReceiverOperation { + pub fn new(runtime_action_set: Rc, current_index: usize) -> Self { + Self { + runtime_action_set, + current_index, + } + } + + pub fn digest_grpc_response(self, msg: &[u8]) -> Vec { + let result = self + .runtime_action_set + .process_grpc_response(self.current_index, msg); + + match result { + Ok((next_msg, headers)) => { + let mut operations = Vec::new(); + if !headers.is_empty() { + operations.push(Operation::AddHeaders(HeadersOperation::new(headers))) + } + operations.push(match next_msg { + None => Operation::Done(), + Some(indexed_req) => SendGrpcRequest(GrpcMessageSenderOperation::new( + self.runtime_action_set, + indexed_req, + )), + }); + operations + } + Err(grpc_err_resp) => vec![Operation::Die(grpc_err_resp)], + } + } + + pub fn fail(self) -> Operation { + match self.runtime_action_set.runtime_actions[self.current_index].get_failure_mode() { + FailureMode::Deny => Operation::Die(GrpcErrResponse::new_internal_server_error()), + FailureMode::Allow => match self + .runtime_action_set + .find_next_grpc_request(self.current_index + 1) + { + None => Operation::Done(), + Some(indexed_req) => Operation::SendGrpcRequest(GrpcMessageSenderOperation::new( + self.runtime_action_set, + indexed_req, + )), + }, + } + } +} + +pub struct HeadersOperation { + headers: Headers, +} + +impl HeadersOperation { + pub fn new(headers: Headers) -> Self { + Self { headers } + } + + pub fn headers(self) -> Headers { + self.headers + } +} diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index 4716387b..714d3ac4 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -1,7 +1,6 @@ +use crate::action_set_index::ActionSetIndex; use crate::configuration::PluginConfiguration; -use crate::filter::http_context::Filter; -use crate::operation_dispatcher::OperationDispatcher; -use crate::runtime_config::RuntimeConfig; +use crate::filter::kuadrant_filter::KuadrantFilter; use crate::service::HeaderResolver; use const_format::formatcp; use log::{debug, error, info}; @@ -17,7 +16,7 @@ const WASM_SHIM_HEADER: &str = "Kuadrant wasm module"; pub struct FilterRoot { pub context_id: u32, - pub config: Rc, + pub action_set_index: Rc, } impl RootContext for FilterRoot { @@ -36,12 +35,11 @@ impl RootContext for FilterRoot { fn create_http_context(&self, context_id: u32) -> Option> { debug!("#{} create_http_context", context_id); let header_resolver = Rc::new(HeaderResolver::new()); - Some(Box::new(Filter { + Some(Box::new(KuadrantFilter::new( context_id, - config: Rc::clone(&self.config), - response_headers_to_add: Vec::default(), - operation_dispatcher: OperationDispatcher::new(header_resolver).into(), - })) + Rc::clone(&self.action_set_index), + header_resolver, + ))) } fn on_configure(&mut self, _config_size: usize) -> bool { @@ -53,15 +51,15 @@ impl RootContext for FilterRoot { match serde_json::from_slice::(&configuration) { Ok(config) => { info!("plugin config parsed: {:?}", config); - let runtime_config = - match >::try_into(config) { + let action_set_index = + match >::try_into(config) { Ok(cfg) => cfg, Err(err) => { error!("failed to compile plugin config: {}", err); return false; } }; - self.config = Rc::new(runtime_config); + self.action_set_index = Rc::new(action_set_index); } Err(e) => { error!("failed to parse plugin config: {}", e); diff --git a/src/lib.rs b/src/lib.rs index 2062ebd1..227ef353 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,13 +6,49 @@ mod data; mod envoy; mod filter; mod glob; -mod operation_dispatcher; mod ratelimit_action; mod runtime_action; mod runtime_action_set; mod runtime_config; mod service; +#[cfg_attr( + all( + target_arch = "wasm32", + target_vendor = "unknown", + target_os = "unknown" + ), + export_name = "_start" +)] +#[cfg_attr( + not(all( + target_arch = "wasm32", + target_vendor = "unknown", + target_os = "unknown" + )), + allow(dead_code) +)] +// This is a C interface, so make it explicit in the fn signature (and avoid mangling) +extern "C" fn start() { + use crate::filter::root_context::FilterRoot; + use log::info; + use proxy_wasm::traits::RootContext; + use proxy_wasm::types::LogLevel; + + proxy_wasm::set_log_level(LogLevel::Trace); + std::panic::set_hook(Box::new(|panic_info| { + proxy_wasm::hostcalls::log(LogLevel::Critical, &panic_info.to_string()) + .expect("failed to log panic_info"); + })); + proxy_wasm::set_root_context(|context_id| -> Box { + info!("#{} set_root_context", context_id); + Box::new(FilterRoot { + context_id, + action_set_index: Default::default(), + }) + }); +} + #[cfg(test)] mod tests { use crate::envoy::{HeaderValue, RateLimitResponse, RateLimitResponse_Code}; diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs deleted file mode 100644 index 156b08f4..00000000 --- a/src/operation_dispatcher.rs +++ /dev/null @@ -1,548 +0,0 @@ -use crate::configuration::{FailureMode, ServiceType}; -use crate::runtime_action::RuntimeAction; -use crate::service::grpc_message::GrpcMessageRequest; -use crate::service::{ - GetMapValuesBytesFn, GrpcCallFn, GrpcMessageBuildFn, GrpcServiceHandler, HeaderResolver, -}; -use log::{debug, error}; -use proxy_wasm::hostcalls; -use proxy_wasm::types::{Bytes, MapType, Status}; -use std::cell::RefCell; -use std::collections::HashMap; -use std::fmt; -use std::rc::Rc; -use std::time::Duration; - -#[derive(PartialEq, Debug, Clone, Copy)] -pub(crate) enum State { - Pending, - Waiting, - Done, -} - -impl State { - fn next(&mut self) { - match self { - State::Pending => *self = State::Waiting, - State::Waiting => *self = State::Done, - _ => {} - } - } - - fn done(&mut self) { - *self = State::Done - } -} - -#[derive(Debug)] -pub(crate) struct Operation { - state: RefCell, - result: RefCell>, - action: Rc, - service_handler: GrpcServiceHandler, - grpc_call_fn: GrpcCallFn, - get_map_values_bytes_fn: GetMapValuesBytesFn, - grpc_message_build_fn: GrpcMessageBuildFn, - conditions_apply_fn: ConditionsApplyFn, -} - -impl Operation { - pub fn new(action: Rc, service_handler: GrpcServiceHandler) -> Self { - Self { - state: RefCell::new(State::Pending), - result: RefCell::new(Ok(0)), // Heuristics: zero represents that it's not been triggered, following `hostcalls` example - action, - service_handler, - grpc_call_fn, - get_map_values_bytes_fn, - grpc_message_build_fn, - conditions_apply_fn, - } - } - - fn trigger(&self) -> Result { - if let Some(message) = (self.grpc_message_build_fn)(&self.action) { - let res = self.service_handler.send( - self.get_map_values_bytes_fn, - self.grpc_call_fn, - message, - self.action.get_timeout(), - ); - match res { - Ok(token_id) => self.set_result(Ok(token_id)), - Err(status) => { - self.set_result(Err(OperationError::new(status, self.get_failure_mode()))) - } - } - self.next_state(); - self.get_result() - } else { - self.done(); - self.get_result() - } - } - - fn next_state(&self) { - self.state.borrow_mut().next() - } - - fn done(&self) { - self.state.borrow_mut().done() - } - - pub fn get_state(&self) -> State { - *self.state.borrow() - } - - pub fn get_result(&self) -> Result { - *self.result.borrow() - } - - fn set_result(&self, result: Result) { - *self.result.borrow_mut() = result; - } - - pub fn get_service_type(&self) -> ServiceType { - self.action.get_service_type() - } - - pub fn get_failure_mode(&self) -> FailureMode { - self.action.get_failure_mode() - } -} - -#[derive(Copy, Clone, Debug, PartialEq)] -pub struct OperationError { - pub status: Status, - pub failure_mode: FailureMode, -} - -impl OperationError { - fn new(status: Status, failure_mode: FailureMode) -> Self { - Self { - status, - failure_mode, - } - } -} - -impl fmt::Display for OperationError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.status { - Status::ParseFailure => { - write!(f, "Error parsing configuration.") - } - _ => { - write!(f, "Error triggering the operation. {:?}", self.status) - } - } - } -} - -pub struct OperationDispatcher { - operations: Vec>, - waiting_operations: HashMap>, - header_resolver: Rc, -} - -impl OperationDispatcher { - pub fn new(header_resolver: Rc) -> Self { - Self { - operations: vec![], - waiting_operations: HashMap::new(), - header_resolver: Rc::clone(&header_resolver), - } - } - - pub fn get_waiting_operation(&self, token_id: u32) -> Result, OperationError> { - let op = self.waiting_operations.get(&token_id); - match op { - Some(op) => { - op.next_state(); - Ok(op.clone()) - } - None => Err(OperationError::new( - Status::NotFound, - FailureMode::default(), - )), - } - } - - pub fn build_operations(&mut self, actions: &[Rc]) { - let mut operations: Vec> = vec![]; - for action in actions.iter() { - operations.push(Rc::new(Operation::new( - Rc::clone(action), - GrpcServiceHandler::new(action.grpc_service(), Rc::clone(&self.header_resolver)), - ))); - } - self.push_operations(operations); - } - - pub fn push_operations(&mut self, operations: Vec>) { - self.operations.extend(operations); - } - - pub fn next(&mut self) -> Result>, OperationError> { - if let Some((i, operation)) = self.operations.iter_mut().enumerate().next() { - match operation.get_state() { - State::Pending => { - if (operation.conditions_apply_fn)(&operation.action) { - match operation.trigger() { - Ok(token_id) => { - match operation.get_state() { - State::Pending => { - panic!("Operation dispatcher reached an undefined state"); - } - State::Waiting => { - // We index only if it was just transitioned to Waiting after triggering - self.waiting_operations.insert(token_id, operation.clone()); - // TODO(didierofrivia): Decide on indexing the failed operations. - Ok(Some(operation.clone())) - } - State::Done => self.next(), - } - } - Err(err) => { - error!("{err:?}"); - Err(err) - } - } - } else { - debug!("actions conditions do not apply, skipping"); - self.operations.remove(i); - self.next() - } - } - State::Waiting => { - operation.next_state(); - Ok(Some(operation.clone())) - } - State::Done => { - if let Ok(token_id) = operation.get_result() { - self.waiting_operations.remove(&token_id); - } // If result was Err, means the operation wasn't indexed - self.operations.remove(i); - self.next() - } - } - } else { - Ok(None) - } - } - - #[cfg(test)] - pub fn default() -> Self { - OperationDispatcher { - operations: vec![], - waiting_operations: HashMap::default(), - header_resolver: Rc::new(HeaderResolver::default()), - } - } - - #[cfg(test)] - pub fn get_current_operation_state(&self) -> Option { - self.operations - .first() - .map(|operation| operation.get_state()) - } -} - -fn grpc_call_fn( - upstream_name: &str, - service_name: &str, - method_name: &str, - initial_metadata: Vec<(&str, &[u8])>, - message: Option<&[u8]>, - timeout: Duration, -) -> Result { - hostcalls::dispatch_grpc_call( - upstream_name, - service_name, - method_name, - initial_metadata, - message, - timeout, - ) -} - -fn get_map_values_bytes_fn(map_type: MapType, key: &str) -> Result, Status> { - hostcalls::get_map_value_bytes(map_type, key) -} - -fn grpc_message_build_fn(action: &RuntimeAction) -> Option { - GrpcMessageRequest::new(action) -} - -type ConditionsApplyFn = fn(action: &RuntimeAction) -> bool; - -fn conditions_apply_fn(action: &RuntimeAction) -> bool { - action.conditions_apply() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::auth_action::AuthAction; - use crate::configuration::{Action, Service, Timeout}; - use crate::envoy::RateLimitRequest; - use crate::ratelimit_action::RateLimitAction; - use protobuf::RepeatedField; - use std::rc::Rc; - use std::time::Duration; - - fn default_grpc_call_fn_stub( - _upstream_name: &str, - _service_name: &str, - _method_name: &str, - _initial_metadata: Vec<(&str, &[u8])>, - _message: Option<&[u8]>, - _timeout: Duration, - ) -> Result { - Ok(200) - } - - fn get_map_values_bytes_fn_stub( - _map_type: MapType, - _key: &str, - ) -> Result, Status> { - Ok(Some(Vec::new())) - } - - fn grpc_message_build_fn_stub(_action: &RuntimeAction) -> Option { - Some(GrpcMessageRequest::RateLimit(build_message())) - } - - fn build_grpc_service_handler() -> GrpcServiceHandler { - GrpcServiceHandler::new(Rc::new(Default::default()), Default::default()) - } - - fn conditions_apply_fn_stub(_action: &RuntimeAction) -> bool { - true - } - - fn build_message() -> RateLimitRequest { - RateLimitRequest { - domain: "example.org".to_string(), - descriptors: RepeatedField::new(), - hits_addend: 1, - unknown_fields: Default::default(), - cached_size: Default::default(), - } - } - - fn build_auth_grpc_action() -> RuntimeAction { - let service = Service { - service_type: ServiceType::Auth, - endpoint: "local".to_string(), - failure_mode: FailureMode::Deny, - timeout: Timeout(Duration::from_millis(42)), - }; - let action = Action { - service: "local".to_string(), - scope: "".to_string(), - predicates: vec![], - data: vec![], - }; - RuntimeAction::Auth( - AuthAction::new(&action, &service).expect("empty predicates should compile!"), - ) - } - - fn build_rate_limit_grpc_action() -> RuntimeAction { - let service = Service { - service_type: ServiceType::RateLimit, - endpoint: "local".to_string(), - failure_mode: FailureMode::Deny, - timeout: Timeout(Duration::from_millis(42)), - }; - let action = Action { - service: "local".to_string(), - scope: "".to_string(), - predicates: vec![], - data: vec![], - }; - RuntimeAction::RateLimit( - RateLimitAction::new(&action, &service).expect("empty predicates should compile!"), - ) - } - - fn build_operation(grpc_call_fn_stub: GrpcCallFn, action: RuntimeAction) -> Rc { - Rc::new(Operation { - state: RefCell::from(State::Pending), - result: RefCell::new(Ok(0)), - action: Rc::new(action), - service_handler: build_grpc_service_handler(), - grpc_call_fn: grpc_call_fn_stub, - get_map_values_bytes_fn: get_map_values_bytes_fn_stub, - grpc_message_build_fn: grpc_message_build_fn_stub, - conditions_apply_fn: conditions_apply_fn_stub, - }) - } - - #[test] - fn operation_getters() { - let operation = build_operation(default_grpc_call_fn_stub, build_rate_limit_grpc_action()); - - assert_eq!(operation.get_state(), State::Pending); - assert_eq!(operation.get_service_type(), ServiceType::RateLimit); - assert_eq!(operation.get_failure_mode(), FailureMode::Deny); - assert_eq!(operation.get_result(), Ok(0)); - } - - #[test] - fn operation_transition() { - let operation = build_operation(default_grpc_call_fn_stub, build_rate_limit_grpc_action()); - assert_eq!(operation.get_result(), Ok(0)); - assert_eq!(operation.get_state(), State::Pending); - let mut res = operation.trigger(); - assert_eq!(res, Ok(200)); - assert_eq!(operation.get_state(), State::Waiting); - res = operation.trigger(); - assert_eq!(res, Ok(200)); - assert_eq!(operation.get_result(), Ok(200)); - assert_eq!(operation.get_state(), State::Done); - } - - #[test] - fn operation_dispatcher_push_actions() { - let mut operation_dispatcher = OperationDispatcher::default(); - - assert_eq!(operation_dispatcher.operations.len(), 0); - let operation = build_operation(default_grpc_call_fn_stub, build_rate_limit_grpc_action()); - operation_dispatcher.push_operations(vec![operation]); - - assert_eq!(operation_dispatcher.operations.len(), 1); - } - - #[test] - fn operation_dispatcher_get_current_action_state() { - let mut operation_dispatcher = OperationDispatcher::default(); - let operation = build_operation(default_grpc_call_fn_stub, build_rate_limit_grpc_action()); - operation_dispatcher.push_operations(vec![operation]); - assert_eq!( - operation_dispatcher.get_current_operation_state(), - Some(State::Pending) - ); - } - - #[test] - fn operation_dispatcher_next() { - let mut operation_dispatcher = OperationDispatcher::default(); - - fn grpc_call_fn_stub_66( - _upstream_name: &str, - _service_name: &str, - _method_name: &str, - _initial_metadata: Vec<(&str, &[u8])>, - _message: Option<&[u8]>, - _timeout: Duration, - ) -> Result { - Ok(66) - } - - fn grpc_call_fn_stub_77( - _upstream_name: &str, - _service_name: &str, - _method_name: &str, - _initial_metadata: Vec<(&str, &[u8])>, - _message: Option<&[u8]>, - _timeout: Duration, - ) -> Result { - Ok(77) - } - - operation_dispatcher.push_operations(vec![ - build_operation(grpc_call_fn_stub_66, build_rate_limit_grpc_action()), - build_operation(grpc_call_fn_stub_77, build_auth_grpc_action()), - ]); - - assert_eq!( - operation_dispatcher.get_current_operation_state(), - Some(State::Pending) - ); - assert_eq!(operation_dispatcher.waiting_operations.len(), 0); - - let mut op = operation_dispatcher.next(); - assert_eq!( - op.clone() - .expect("ok result") - .expect("operation is some") - .get_result(), - Ok(66) - ); - assert_eq!( - op.clone() - .expect("ok result") - .expect("operation is some") - .get_service_type(), - ServiceType::RateLimit - ); - assert_eq!( - op.expect("ok result") - .expect("operation is some") - .get_state(), - State::Waiting - ); - assert_eq!(operation_dispatcher.waiting_operations.len(), 1); - - op = operation_dispatcher.next(); - assert_eq!( - op.clone() - .expect("ok result") - .expect("operation is some") - .get_result(), - Ok(66) - ); - assert_eq!( - op.expect("ok result") - .expect("operation is some") - .get_state(), - State::Done - ); - - op = operation_dispatcher.next(); - assert_eq!( - op.clone() - .expect("ok result") - .expect("operation is some") - .get_result(), - Ok(77) - ); - assert_eq!( - op.clone() - .expect("ok result") - .expect("operation is some") - .get_service_type(), - ServiceType::Auth - ); - assert_eq!( - op.expect("ok result") - .expect("operation is some") - .get_state(), - State::Waiting - ); - assert_eq!(operation_dispatcher.waiting_operations.len(), 1); - - op = operation_dispatcher.next(); - assert_eq!( - op.clone() - .expect("ok result") - .expect("operation is some") - .get_result(), - Ok(77) - ); - assert_eq!( - op.expect("ok result") - .expect("operation is some") - .get_state(), - State::Done - ); - assert_eq!(operation_dispatcher.waiting_operations.len(), 1); - - op = operation_dispatcher.next(); - assert!(op.expect("ok result").is_none()); - assert!(operation_dispatcher.get_current_operation_state().is_none()); - assert_eq!(operation_dispatcher.waiting_operations.len(), 0); - } -} diff --git a/src/ratelimit_action.rs b/src/ratelimit_action.rs index d093396d..c8d6cf03 100644 --- a/src/ratelimit_action.rs +++ b/src/ratelimit_action.rs @@ -1,10 +1,13 @@ use crate::configuration::{Action, DataType, FailureMode, Service}; use crate::data::Expression; use crate::data::Predicate; -use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry}; -use crate::service::GrpcService; +use crate::envoy::{ + HeaderValue, RateLimitDescriptor, RateLimitDescriptor_Entry, RateLimitResponse, + RateLimitResponse_Code, StatusCode, +}; +use crate::service::{GrpcErrResponse, GrpcService, Headers}; use cel_interpreter::Value; -use log::error; +use log::{debug, error}; use protobuf::RepeatedField; use std::rc::Rc; @@ -160,6 +163,55 @@ impl RateLimitAction { } Some(other) } + + pub fn process_response( + &self, + rate_limit_response: RateLimitResponse, + ) -> Result, GrpcErrResponse> { + match rate_limit_response { + RateLimitResponse { + overall_code: RateLimitResponse_Code::UNKNOWN, + .. + } => { + debug!("process_response(rl): received UNKNOWN response"); + match self.get_failure_mode() { + FailureMode::Deny => Err(GrpcErrResponse::new_internal_server_error()), + FailureMode::Allow => { + debug!("process_response(rl): continuing as FailureMode Allow"); + Ok(Vec::default()) + } + } + } + RateLimitResponse { + overall_code: RateLimitResponse_Code::OVER_LIMIT, + response_headers_to_add: rl_headers, + .. + } => { + debug!("process_response(rl): received OVER_LIMIT response"); + let response_headers = Self::get_header_vec(rl_headers); + Err(GrpcErrResponse::new( + StatusCode::TooManyRequests as u32, + response_headers, + "Too Many Requests\n".to_string(), + )) + } + RateLimitResponse { + overall_code: RateLimitResponse_Code::OK, + response_headers_to_add: additional_headers, + .. + } => { + debug!("process_response(rl): received OK response"); + Ok(Self::get_header_vec(additional_headers)) + } + } + } + + fn get_header_vec(headers: RepeatedField) -> Headers { + headers + .iter() + .map(|header| (header.key.to_owned(), header.value.to_owned())) + .collect() + } } #[cfg(test)] @@ -171,10 +223,14 @@ mod test { }; fn build_service() -> Service { + build_service_with_failure_mode(FailureMode::default()) + } + + fn build_service_with_failure_mode(failure_mode: FailureMode) -> Service { Service { service_type: ServiceType::RateLimit, endpoint: "some_endpoint".into(), - failure_mode: FailureMode::default(), + failure_mode, timeout: Timeout::default(), } } @@ -188,6 +244,35 @@ mod test { } } + fn build_ratelimit_response( + status: RateLimitResponse_Code, + headers: Option>, + ) -> RateLimitResponse { + let mut response = RateLimitResponse::new(); + response.set_overall_code(status); + match status { + RateLimitResponse_Code::UNKNOWN => {} + RateLimitResponse_Code::OVER_LIMIT | RateLimitResponse_Code::OK => { + if let Some(header_list) = headers { + response.set_response_headers_to_add(build_headers(header_list)) + } + } + } + response + } + + fn build_headers(headers: Vec<(&str, &str)>) -> RepeatedField { + headers + .into_iter() + .map(|(key, value)| { + let mut hv = HeaderValue::new(); + hv.set_key(key.to_string()); + hv.set_value(value.to_string()); + hv + }) + .collect::>() + } + #[test] fn empty_predicates_do_apply() { let action = build_action(Vec::default(), Vec::default()); @@ -317,4 +402,109 @@ mod test { assert_eq!(descriptor.get_entries()[1].key, String::from("key_3")); assert_eq!(descriptor.get_entries()[1].value, String::from("value_3")); } + + #[test] + fn process_ok_response() { + let action = build_action(Vec::default(), Vec::default()); + let service = build_service(); + let rl_action = RateLimitAction::new(&action, &service) + .expect("action building failed. Maybe predicates compilation?"); + + let ok_response_without_headers = + build_ratelimit_response(RateLimitResponse_Code::OK, None); + let result = rl_action.process_response(ok_response_without_headers); + assert!(result.is_ok()); + + let headers = result.expect("is ok"); + assert!(headers.is_empty()); + + let ok_response_with_header = build_ratelimit_response( + RateLimitResponse_Code::OK, + Some(vec![("my_header", "my_value")]), + ); + let result = rl_action.process_response(ok_response_with_header); + assert!(result.is_ok()); + + let headers = result.expect("is ok"); + assert!(!headers.is_empty()); + + assert_eq!( + headers[0], + ("my_header".to_string(), "my_value".to_string()) + ); + } + + #[test] + fn process_overlimit_response() { + let headers = vec![("x-ratelimit-limit", "10"), ("x-ratelimit-remaining", "0")]; + let action = build_action(Vec::default(), Vec::default()); + let service = build_service(); + let rl_action = RateLimitAction::new(&action, &service) + .expect("action building failed. Maybe predicates compilation?"); + + let overlimit_response_empty = + build_ratelimit_response(RateLimitResponse_Code::OVER_LIMIT, None); + let result = rl_action.process_response(overlimit_response_empty); + assert!(result.is_err()); + + let grpc_err_response = result.expect_err("is err"); + assert_eq!( + grpc_err_response.status_code(), + StatusCode::TooManyRequests as u32 + ); + assert!(grpc_err_response.headers().is_empty()); + assert_eq!(grpc_err_response.body(), "Too Many Requests\n"); + + let denied_response_headers = + build_ratelimit_response(RateLimitResponse_Code::OVER_LIMIT, Some(headers.clone())); + let result = rl_action.process_response(denied_response_headers); + assert!(result.is_err()); + + let grpc_err_response = result.expect_err("is err"); + assert_eq!( + grpc_err_response.status_code(), + StatusCode::TooManyRequests as u32 + ); + + let response_headers = grpc_err_response.headers(); + headers.iter().zip(response_headers.iter()).for_each( + |((header_one, value_one), (header_two, value_two))| { + assert_eq!(header_one, header_two); + assert_eq!(value_one, value_two); + }, + ); + + assert_eq!(grpc_err_response.body(), "Too Many Requests\n"); + } + + #[test] + fn process_error_response() { + let action = build_action(Vec::default(), Vec::default()); + let deny_service = build_service_with_failure_mode(FailureMode::Deny); + let rl_action = RateLimitAction::new(&action, &deny_service) + .expect("action building failed. Maybe predicates compilation?"); + + let error_response = build_ratelimit_response(RateLimitResponse_Code::UNKNOWN, None); + let result = rl_action.process_response(error_response.clone()); + assert!(result.is_err()); + + let grpc_err_response = result.expect_err("is err"); + assert_eq!( + grpc_err_response.status_code(), + StatusCode::InternalServerError as u32 + ); + + assert!(grpc_err_response.headers().is_empty()); + assert_eq!(grpc_err_response.body(), "Internal Server Error.\n"); + + let allow_service = build_service_with_failure_mode(FailureMode::Allow); + let rl_action = RateLimitAction::new(&action, &allow_service) + .expect("action building failed. Maybe predicates compilation?"); + + let result = rl_action.process_response(error_response); + assert!(result.is_ok()); + + let headers = result.expect("is ok"); + assert!(headers.is_empty()); + } } diff --git a/src/runtime_action.rs b/src/runtime_action.rs index 464b04f4..7b384c39 100644 --- a/src/runtime_action.rs +++ b/src/runtime_action.rs @@ -1,10 +1,13 @@ use crate::auth_action::AuthAction; use crate::configuration::{Action, FailureMode, Service, ServiceType}; use crate::ratelimit_action::RateLimitAction; -use crate::service::GrpcService; +use crate::service::auth::AuthService; +use crate::service::rate_limit::RateLimitService; +use crate::service::{GrpcErrResponse, GrpcRequest, GrpcService, Headers}; +use log::debug; +use protobuf::Message; use std::collections::HashMap; use std::rc::Rc; -use std::time::Duration; #[derive(Debug)] pub enum RuntimeAction { @@ -45,14 +48,6 @@ impl RuntimeAction { } } - pub fn get_timeout(&self) -> Duration { - self.grpc_service().get_timeout() - } - - pub fn get_service_type(&self) -> ServiceType { - self.grpc_service().get_service_type() - } - #[must_use] pub fn merge(&mut self, other: RuntimeAction) -> Option { // only makes sense for rate limiting actions @@ -63,6 +58,65 @@ impl RuntimeAction { } Some(other) } + + pub fn process_request(&self) -> Option { + if !self.conditions_apply() { + None + } else { + self.grpc_service().build_request(self.build_message()) + } + } + + pub fn process_response(&self, msg: &[u8]) -> Result { + match self { + Self::Auth(auth_action) => match Message::parse_from_bytes(msg) { + Ok(check_response) => auth_action.process_response(check_response), + Err(e) => { + debug!("process_response(auth): failed to parse response `{e:?}`"); + match self.get_failure_mode() { + FailureMode::Deny => Err(GrpcErrResponse::new_internal_server_error()), + FailureMode::Allow => { + debug!("process_response(auth): continuing as FailureMode Allow"); + Ok(Vec::default()) + } + } + } + }, + Self::RateLimit(rl_action) => match Message::parse_from_bytes(msg) { + Ok(rate_limit_response) => rl_action.process_response(rate_limit_response), + Err(e) => { + debug!("process_response(rl): failed to parse response `{e:?}`"); + match self.get_failure_mode() { + FailureMode::Deny => Err(GrpcErrResponse::new_internal_server_error()), + FailureMode::Allow => { + debug!("process_response(rl): continuing as FailureMode Allow"); + Ok(Vec::default()) + } + } + } + }, + } + } + + pub fn build_message(&self) -> Option> { + match self { + RuntimeAction::RateLimit(rl_action) => { + let descriptor = rl_action.build_descriptor(); + if descriptor.entries.is_empty() { + debug!("build_message(rl): empty descriptors"); + None + } else { + RateLimitService::request_message_as_bytes( + String::from(rl_action.scope()), + vec![descriptor].into(), + ) + } + } + RuntimeAction::Auth(auth_action) => { + AuthService::request_message_as_bytes(String::from(auth_action.scope())) + } + } + } } #[cfg(test)] diff --git a/src/runtime_action_set.rs b/src/runtime_action_set.rs index 840ed6e2..ae7a1369 100644 --- a/src/runtime_action_set.rs +++ b/src/runtime_action_set.rs @@ -1,7 +1,7 @@ use crate::configuration::{ActionSet, Service}; -use crate::data::Predicate; +use crate::data::{Predicate, PredicateVec}; use crate::runtime_action::RuntimeAction; -use log::error; +use crate::service::{GrpcErrResponse, Headers, IndexedGrpcRequest}; use std::collections::HashMap; use std::rc::Rc; @@ -57,14 +57,35 @@ impl RuntimeActionSet { } pub fn conditions_apply(&self) -> bool { - let predicates = &self.route_rule_predicates; - predicates.is_empty() - || predicates.iter().all(|predicate| match predicate.test() { - Ok(b) => b, - Err(err) => { - error!("Failed to evaluate {:?}: {}", predicate, err); - panic!("Err out of this!") - } + self.route_rule_predicates.apply() + } + + pub fn find_first_grpc_request(&self) -> Option { + self.find_next_grpc_request(0) + } + + pub fn find_next_grpc_request(&self, start: usize) -> Option { + self.runtime_actions + .iter() + .skip(start) + .enumerate() + .find_map(|(i, action)| { + action + .process_request() + .map(|request| IndexedGrpcRequest::new(start + i, request)) + }) + } + + pub fn process_grpc_response( + &self, + index: usize, + msg: &[u8], + ) -> Result<(Option, Headers), GrpcErrResponse> { + self.runtime_actions[index] + .process_response(msg) + .map(|headers| { + let next_msg = self.find_next_grpc_request(index + 1); + (next_msg, headers) }) } } diff --git a/src/runtime_config.rs b/src/runtime_config.rs index 25c63cbb..58db3a9b 100644 --- a/src/runtime_config.rs +++ b/src/runtime_config.rs @@ -3,11 +3,7 @@ use crate::configuration::PluginConfiguration; use crate::runtime_action_set::RuntimeActionSet; use std::rc::Rc; -pub(crate) struct RuntimeConfig { - pub index: ActionSetIndex, -} - -impl TryFrom for RuntimeConfig { +impl TryFrom for ActionSetIndex { type Error = String; fn try_from(config: PluginConfiguration) -> Result { @@ -19,15 +15,13 @@ impl TryFrom for RuntimeConfig { } } - Ok(Self { index }) + Ok(index) } } -impl Default for RuntimeConfig { +impl Default for ActionSetIndex { fn default() -> Self { - Self { - index: ActionSetIndex::new(), - } + ActionSetIndex::new() } } @@ -97,21 +91,15 @@ mod test { } assert!(res.is_ok()); - let result = RuntimeConfig::try_from(res.unwrap()); - let runtime_config = result.expect("That didn't work"); - let rlp_option = runtime_config - .index - .get_longest_match_action_sets("example.com"); + let result = ActionSetIndex::try_from(res.unwrap()); + let index = result.expect("That didn't work"); + let rlp_option = index.get_longest_match_action_sets("example.com"); assert!(rlp_option.is_some()); - let rlp_option = runtime_config - .index - .get_longest_match_action_sets("test.toystore.com"); + let rlp_option = index.get_longest_match_action_sets("test.toystore.com"); assert!(rlp_option.is_some()); - let rlp_option = runtime_config - .index - .get_longest_match_action_sets("unknown"); + let rlp_option = index.get_longest_match_action_sets("unknown"); assert!(rlp_option.is_none()); } @@ -151,7 +139,7 @@ mod test { } assert!(serde_res.is_ok()); - let result = RuntimeConfig::try_from(serde_res.expect("That didn't work")); + let result = ActionSetIndex::try_from(serde_res.expect("That didn't work")); assert_eq!(result.err(), Some("Unknown service: unknown".into())); } } diff --git a/src/service.rs b/src/service.rs index 1145c1e0..5c573547 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,20 +1,12 @@ pub(crate) mod auth; -pub(crate) mod grpc_message; pub(crate) mod rate_limit; use crate::configuration::{FailureMode, Service, ServiceType}; use crate::envoy::StatusCode; -use crate::operation_dispatcher::Operation; -use crate::runtime_action::RuntimeAction; -use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; -use crate::service::grpc_message::{GrpcMessageRequest, GrpcMessageResponse}; -use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; -use log::warn; -use protobuf::Message; -use proxy_wasm::hostcalls; -use proxy_wasm::types::Status::SerializationFailure; -use proxy_wasm::types::{BufferType, Bytes, MapType, Status}; +use proxy_wasm::types::Bytes; use std::cell::OnceCell; use std::rc::Rc; use std::time::Duration; @@ -46,10 +38,6 @@ impl GrpcService { self.service.timeout.0 } - pub fn get_service_type(&self) -> ServiceType { - self.service.service_type.clone() - } - pub fn get_failure_mode(&self) -> FailureMode { self.service.failure_mode } @@ -63,115 +51,123 @@ impl GrpcService { fn method(&self) -> &str { self.method } + pub fn build_request(&self, message: Option>) -> Option { + message.map(|msg| { + GrpcRequest::new( + self.endpoint(), + self.name(), + self.method(), + self.get_timeout(), + Some(msg), + ) + }) + } +} - pub fn process_grpc_response( - operation: Rc, - resp_size: usize, - ) -> Result { - let failure_mode = operation.get_failure_mode(); - if let Ok(Some(res_body_bytes)) = - hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size) - { - match GrpcMessageResponse::new(&operation.get_service_type(), &res_body_bytes) { - Ok(res) => match operation.get_service_type() { - ServiceType::Auth => AuthService::process_auth_grpc_response(res, failure_mode), - ServiceType::RateLimit => { - RateLimitService::process_ratelimit_grpc_response(res, failure_mode) - } - }, - Err(e) => { - warn!( - "failed to parse grpc response body into GrpcMessageResponse message: {e}" - ); - GrpcService::handle_error_on_grpc_response(failure_mode); - Err(StatusCode::InternalServerError) - } - } - } else { - warn!("failed to get grpc buffer or return data is null!"); - GrpcService::handle_error_on_grpc_response(failure_mode); - Err(StatusCode::InternalServerError) - } +pub struct IndexedGrpcRequest { + index: usize, + request: GrpcRequest, +} + +impl IndexedGrpcRequest { + pub(crate) fn new(index: usize, request: GrpcRequest) -> Self { + Self { index, request } } - pub fn handle_error_on_grpc_response(failure_mode: FailureMode) { - match failure_mode { - FailureMode::Deny => { - hostcalls::send_http_response(500, vec![], Some(b"Internal Server Error.\n")) - .expect("failed to send_http_response 500"); - } - FailureMode::Allow => {} - } + pub fn index(&self) -> usize { + self.index + } + + pub fn request(self) -> GrpcRequest { + self.request } } -pub struct GrpcResult { - pub response_headers: Vec<(String, String)>, +// GrpcRequest contains the information required to make a Grpc Call +pub struct GrpcRequest { + upstream_name: String, + service_name: String, + method_name: String, + timeout: Duration, + message: Option>, } -impl GrpcResult { - pub fn default() -> Self { + +impl GrpcRequest { + pub fn new( + upstream_name: &str, + service_name: &str, + method_name: &str, + timeout: Duration, + message: Option>, + ) -> Self { Self { - response_headers: Vec::new(), + upstream_name: upstream_name.to_owned(), + service_name: service_name.to_owned(), + method_name: method_name.to_owned(), + timeout, + message, } } - pub fn new(response_headers: Vec<(String, String)>) -> Self { - Self { response_headers } + + pub fn upstream_name(&self) -> &str { + &self.upstream_name } -} -pub type GrpcCallFn = fn( - upstream_name: &str, - service_name: &str, - method_name: &str, - initial_metadata: Vec<(&str, &[u8])>, - message: Option<&[u8]>, - timeout: Duration, -) -> Result; + pub fn service_name(&self) -> &str { + &self.service_name + } -pub type GetMapValuesBytesFn = fn(map_type: MapType, key: &str) -> Result, Status>; + pub fn method_name(&self) -> &str { + &self.method_name + } -pub type GrpcMessageBuildFn = fn(action: &RuntimeAction) -> Option; + pub fn timeout(&self) -> Duration { + self.timeout + } + pub fn message(&self) -> Option<&[u8]> { + self.message.as_deref() + } +} + +pub type Headers = Vec<(String, String)>; #[derive(Debug)] -pub struct GrpcServiceHandler { - grpc_service: Rc, - header_resolver: Rc, +pub struct GrpcErrResponse { + status_code: u32, + response_headers: Headers, + body: String, } -impl GrpcServiceHandler { - pub fn new(grpc_service: Rc, header_resolver: Rc) -> Self { +impl GrpcErrResponse { + pub fn new(status_code: u32, response_headers: Headers, body: String) -> Self { Self { - grpc_service, - header_resolver, + status_code, + response_headers, + body, } } - pub fn send( - &self, - get_map_values_bytes_fn: GetMapValuesBytesFn, - grpc_call_fn: GrpcCallFn, - message: GrpcMessageRequest, - timeout: Duration, - ) -> Result { - let msg = Message::write_to_bytes(&message).map_err(|e| { - warn!("Failed to write protobuf message to bytes: {e:?}"); - SerializationFailure - })?; - let metadata = self - .header_resolver - .get(get_map_values_bytes_fn) + pub fn new_internal_server_error() -> Self { + Self { + status_code: StatusCode::InternalServerError as u32, + response_headers: Vec::default(), + body: "Internal Server Error.\n".to_string(), + } + } + + pub fn status_code(&self) -> u32 { + self.status_code + } + + pub fn headers(&self) -> Vec<(&str, &str)> { + self.response_headers .iter() - .map(|(header, value)| (*header, value.as_slice())) - .collect(); - - grpc_call_fn( - self.grpc_service.endpoint(), - self.grpc_service.name(), - self.grpc_service.method(), - metadata, - Some(&msg), - timeout, - ) + .map(|(header, value)| (header.as_str(), value.as_str())) + .collect() + } + + pub fn body(&self) -> &str { + self.body.as_str() } } @@ -193,13 +189,14 @@ impl HeaderResolver { } } - pub fn get(&self, get_map_values_bytes_fn: GetMapValuesBytesFn) -> &Vec<(&'static str, Bytes)> { + pub fn get_with_ctx( + &self, + ctx: &T, + ) -> &Vec<(&'static str, Bytes)> { self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { - if let Ok(Some(value)) = - get_map_values_bytes_fn(MapType::HttpRequestHeaders, (*header).as_str()) - { + if let Some(value) = ctx.get_http_request_header_bytes((*header).as_str()) { headers.push(((*header).as_str(), value)); } } @@ -228,3 +225,45 @@ impl TracingHeader { } } } + +#[cfg(test)] +mod test { + use super::*; + use proxy_wasm::traits::Context; + use std::collections::HashMap; + + struct MockHost { + headers: HashMap<&'static str, Bytes>, + } + + impl MockHost { + pub fn new(headers: HashMap<&'static str, Bytes>) -> Self { + Self { headers } + } + } + + impl Context for MockHost {} + + impl proxy_wasm::traits::HttpContext for MockHost { + fn get_http_request_header_bytes(&self, name: &str) -> Option { + self.headers.get(name).map(|b| b.to_owned()) + } + } + + #[test] + fn read_headers() { + let header_resolver = HeaderResolver::new(); + + let headers: Vec<(&str, Bytes)> = vec![("traceparent", b"xyz".to_vec())]; + let mock_host = MockHost::new(headers.iter().cloned().collect::>()); + + let resolver_headers = header_resolver.get_with_ctx(&mock_host); + + headers.iter().zip(resolver_headers.iter()).for_each( + |((header_one, value_one), (header_two, value_two))| { + assert_eq!(header_one, header_two); + assert_eq!(value_one, value_two); + }, + ) + } +} diff --git a/src/service/auth.rs b/src/service/auth.rs index 925c4762..9994398d 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -1,18 +1,14 @@ -use crate::configuration::FailureMode; -use crate::data::{get_attribute, store_metadata}; +use crate::data::get_attribute; use crate::envoy::{ Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, - AttributeContext_Request, CheckRequest, CheckResponse_oneof_http_response, Metadata, - SocketAddress, StatusCode, + AttributeContext_Request, CheckRequest, Metadata, SocketAddress, }; -use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::{GrpcResult, GrpcService}; use chrono::{DateTime, FixedOffset}; -use log::{debug, warn}; +use log::debug; use protobuf::well_known_types::Timestamp; use protobuf::Message; use proxy_wasm::hostcalls; -use proxy_wasm::types::{Bytes, MapType}; +use proxy_wasm::types::MapType; use std::collections::HashMap; pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; @@ -25,11 +21,11 @@ impl AuthService { AuthService::build_check_req(ce_host) } - pub fn response_message(res_body_bytes: &Bytes) -> GrpcMessageResult { - match Message::parse_from_bytes(res_body_bytes) { - Ok(res) => Ok(GrpcMessageResponse::Auth(res)), - Err(e) => Err(e), - } + pub fn request_message_as_bytes(ce_host: String) -> Option> { + Self::request_message(ce_host) + .write_to_bytes() + .map_err(|e| debug!("Failed to write protobuf message to bytes: {e:?}")) + .ok() } fn build_check_req(ce_host: String) -> CheckRequest { @@ -121,67 +117,4 @@ impl AuthService { peer.set_address(address); peer } - - pub fn process_auth_grpc_response( - auth_resp: GrpcMessageResponse, - failure_mode: FailureMode, - ) -> Result { - if let GrpcMessageResponse::Auth(check_response) = auth_resp { - // store dynamic metadata in filter state - store_metadata(check_response.get_dynamic_metadata()); - - match check_response.http_response { - Some(CheckResponse_oneof_http_response::ok_response(ok_response)) => { - debug!("process_auth_grpc_response: received OkHttpResponse"); - if !ok_response.get_response_headers_to_add().is_empty() { - panic!("process_auth_grpc_response: response contained response_headers_to_add which is unsupported!") - } - if !ok_response.get_headers_to_remove().is_empty() { - panic!("process_auth_grpc_response: response contained headers_to_remove which is unsupported!") - } - if !ok_response.get_query_parameters_to_set().is_empty() { - panic!("process_auth_grpc_response: response contained query_parameters_to_set which is unsupported!") - } - if !ok_response.get_query_parameters_to_remove().is_empty() { - panic!("process_auth_grpc_response: response contained query_parameters_to_remove which is unsupported!") - } - ok_response.get_headers().iter().for_each(|header| { - hostcalls::add_map_value( - MapType::HttpRequestHeaders, - header.get_header().get_key(), - header.get_header().get_value(), - ) - .expect("failed to add_map_value to HttpRequestHeaders") - }); - Ok(GrpcResult::default()) - } - Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { - debug!("process_auth_grpc_response: received DeniedHttpResponse"); - let mut response_headers = vec![]; - let status_code = denied_response.get_status().code; - denied_response.get_headers().iter().for_each(|header| { - response_headers.push(( - header.get_header().get_key(), - header.get_header().get_value(), - )) - }); - hostcalls::send_http_response( - status_code as u32, - response_headers, - Some(denied_response.get_body().as_ref()), - ) - .expect("failed to send_http_response"); - Err(status_code) - } - None => { - GrpcService::handle_error_on_grpc_response(failure_mode); - Err(StatusCode::InternalServerError) - } - } - } else { - warn!("not a GrpcMessageResponse::Auth(CheckResponse)!"); - GrpcService::handle_error_on_grpc_response(failure_mode); - Err(StatusCode::InternalServerError) - } - } } diff --git a/src/service/grpc_message.rs b/src/service/grpc_message.rs deleted file mode 100644 index 00148319..00000000 --- a/src/service/grpc_message.rs +++ /dev/null @@ -1,269 +0,0 @@ -use crate::configuration::ServiceType; -use crate::envoy::{CheckRequest, CheckResponse, RateLimitRequest, RateLimitResponse}; -use crate::runtime_action::RuntimeAction; -use crate::service::auth::AuthService; -use crate::service::rate_limit::RateLimitService; -use log::debug; -use protobuf::reflect::MessageDescriptor; -use protobuf::{ - Clear, CodedInputStream, CodedOutputStream, Message, ProtobufError, ProtobufResult, - UnknownFields, -}; -use proxy_wasm::types::Bytes; -use std::any::Any; - -#[derive(Clone, Debug)] -pub enum GrpcMessageRequest { - Auth(CheckRequest), - RateLimit(RateLimitRequest), -} - -impl Default for GrpcMessageRequest { - fn default() -> Self { - GrpcMessageRequest::RateLimit(RateLimitRequest::new()) - } -} - -impl Clear for GrpcMessageRequest { - fn clear(&mut self) { - match self { - GrpcMessageRequest::Auth(msg) => msg.clear(), - GrpcMessageRequest::RateLimit(msg) => msg.clear(), - } - } -} - -impl Message for GrpcMessageRequest { - fn descriptor(&self) -> &'static MessageDescriptor { - match self { - GrpcMessageRequest::Auth(msg) => msg.descriptor(), - GrpcMessageRequest::RateLimit(msg) => msg.descriptor(), - } - } - - fn is_initialized(&self) -> bool { - match self { - GrpcMessageRequest::Auth(msg) => msg.is_initialized(), - GrpcMessageRequest::RateLimit(msg) => msg.is_initialized(), - } - } - - fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { - match self { - GrpcMessageRequest::Auth(msg) => msg.merge_from(is), - GrpcMessageRequest::RateLimit(msg) => msg.merge_from(is), - } - } - - fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { - match self { - GrpcMessageRequest::Auth(msg) => msg.write_to_with_cached_sizes(os), - GrpcMessageRequest::RateLimit(msg) => msg.write_to_with_cached_sizes(os), - } - } - - fn write_to_bytes(&self) -> ProtobufResult> { - match self { - GrpcMessageRequest::Auth(msg) => msg.write_to_bytes(), - GrpcMessageRequest::RateLimit(msg) => msg.write_to_bytes(), - } - } - - fn compute_size(&self) -> u32 { - match self { - GrpcMessageRequest::Auth(msg) => msg.compute_size(), - GrpcMessageRequest::RateLimit(msg) => msg.compute_size(), - } - } - - fn get_cached_size(&self) -> u32 { - match self { - GrpcMessageRequest::Auth(msg) => msg.get_cached_size(), - GrpcMessageRequest::RateLimit(msg) => msg.get_cached_size(), - } - } - - fn get_unknown_fields(&self) -> &UnknownFields { - match self { - GrpcMessageRequest::Auth(msg) => msg.get_unknown_fields(), - GrpcMessageRequest::RateLimit(msg) => msg.get_unknown_fields(), - } - } - - fn mut_unknown_fields(&mut self) -> &mut UnknownFields { - match self { - GrpcMessageRequest::Auth(msg) => msg.mut_unknown_fields(), - GrpcMessageRequest::RateLimit(msg) => msg.mut_unknown_fields(), - } - } - - fn as_any(&self) -> &dyn Any { - match self { - GrpcMessageRequest::Auth(msg) => msg.as_any(), - GrpcMessageRequest::RateLimit(msg) => msg.as_any(), - } - } - - fn new() -> Self - where - Self: Sized, - { - // Returning default value - GrpcMessageRequest::default() - } - - fn default_instance() -> &'static Self - where - Self: Sized, - { - #[allow(non_upper_case_globals)] - static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; - instance.get(|| GrpcMessageRequest::RateLimit(RateLimitRequest::new())) - } -} - -impl GrpcMessageRequest { - // Using domain as ce_host for the time being, we might pass a DataType in the future. - pub fn new(action: &RuntimeAction) -> Option { - match action { - RuntimeAction::RateLimit(rl_action) => { - let descriptor = rl_action.build_descriptor(); - if descriptor.entries.is_empty() { - debug!("grpc_message_request: empty descriptors"); - None - } else { - Some(GrpcMessageRequest::RateLimit( - RateLimitService::request_message( - String::from(rl_action.scope()), - vec![descriptor].into(), - ), - )) - } - } - RuntimeAction::Auth(auth_action) => Some(GrpcMessageRequest::Auth( - AuthService::request_message(String::from(auth_action.scope())), - )), - } - } -} - -#[derive(Clone, Debug)] -pub enum GrpcMessageResponse { - Auth(CheckResponse), - RateLimit(RateLimitResponse), -} - -impl Default for GrpcMessageResponse { - fn default() -> Self { - GrpcMessageResponse::RateLimit(RateLimitResponse::new()) - } -} - -impl Clear for GrpcMessageResponse { - fn clear(&mut self) { - todo!() - } -} - -impl Message for GrpcMessageResponse { - fn descriptor(&self) -> &'static MessageDescriptor { - match self { - GrpcMessageResponse::Auth(res) => res.descriptor(), - GrpcMessageResponse::RateLimit(res) => res.descriptor(), - } - } - - fn is_initialized(&self) -> bool { - match self { - GrpcMessageResponse::Auth(res) => res.is_initialized(), - GrpcMessageResponse::RateLimit(res) => res.is_initialized(), - } - } - - fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { - match self { - GrpcMessageResponse::Auth(res) => res.merge_from(is), - GrpcMessageResponse::RateLimit(res) => res.merge_from(is), - } - } - - fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { - match self { - GrpcMessageResponse::Auth(res) => res.write_to_with_cached_sizes(os), - GrpcMessageResponse::RateLimit(res) => res.write_to_with_cached_sizes(os), - } - } - - fn write_to_bytes(&self) -> ProtobufResult> { - match self { - GrpcMessageResponse::Auth(res) => res.write_to_bytes(), - GrpcMessageResponse::RateLimit(res) => res.write_to_bytes(), - } - } - - fn compute_size(&self) -> u32 { - match self { - GrpcMessageResponse::Auth(res) => res.compute_size(), - GrpcMessageResponse::RateLimit(res) => res.compute_size(), - } - } - - fn get_cached_size(&self) -> u32 { - match self { - GrpcMessageResponse::Auth(res) => res.get_cached_size(), - GrpcMessageResponse::RateLimit(res) => res.get_cached_size(), - } - } - - fn get_unknown_fields(&self) -> &UnknownFields { - match self { - GrpcMessageResponse::Auth(res) => res.get_unknown_fields(), - GrpcMessageResponse::RateLimit(res) => res.get_unknown_fields(), - } - } - - fn mut_unknown_fields(&mut self) -> &mut UnknownFields { - match self { - GrpcMessageResponse::Auth(res) => res.mut_unknown_fields(), - GrpcMessageResponse::RateLimit(res) => res.mut_unknown_fields(), - } - } - - fn as_any(&self) -> &dyn Any { - match self { - GrpcMessageResponse::Auth(res) => res.as_any(), - GrpcMessageResponse::RateLimit(res) => res.as_any(), - } - } - - fn new() -> Self - where - Self: Sized, - { - // Returning default value - GrpcMessageResponse::default() - } - - fn default_instance() -> &'static Self - where - Self: Sized, - { - #[allow(non_upper_case_globals)] - static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; - instance.get(|| GrpcMessageResponse::RateLimit(RateLimitResponse::new())) - } -} - -impl GrpcMessageResponse { - pub fn new( - service_type: &ServiceType, - res_body_bytes: &Bytes, - ) -> GrpcMessageResult { - match service_type { - ServiceType::RateLimit => RateLimitService::response_message(res_body_bytes), - ServiceType::Auth => AuthService::response_message(res_body_bytes), - } - } -} - -pub type GrpcMessageResult = Result; diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index b4dcf809..c4842c48 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,13 +1,6 @@ -use crate::configuration::FailureMode; -use crate::envoy::{ - RateLimitDescriptor, RateLimitRequest, RateLimitResponse, RateLimitResponse_Code, StatusCode, -}; -use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::{GrpcResult, GrpcService}; -use log::warn; +use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; +use log::debug; use protobuf::{Message, RepeatedField}; -use proxy_wasm::hostcalls; -use proxy_wasm::types::Bytes; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -28,57 +21,14 @@ impl RateLimitService { } } - pub fn response_message(res_body_bytes: &Bytes) -> GrpcMessageResult { - match Message::parse_from_bytes(res_body_bytes) { - Ok(res) => Ok(GrpcMessageResponse::RateLimit(res)), - Err(e) => Err(e), - } - } - - pub fn process_ratelimit_grpc_response( - rl_resp: GrpcMessageResponse, - failure_mode: FailureMode, - ) -> Result { - match rl_resp { - GrpcMessageResponse::RateLimit(RateLimitResponse { - overall_code: RateLimitResponse_Code::UNKNOWN, - .. - }) => { - GrpcService::handle_error_on_grpc_response(failure_mode); - Err(StatusCode::InternalServerError) - } - GrpcMessageResponse::RateLimit(RateLimitResponse { - overall_code: RateLimitResponse_Code::OVER_LIMIT, - response_headers_to_add: rl_headers, - .. - }) => { - let mut response_headers = vec![]; - for header in &rl_headers { - response_headers.push((header.get_key(), header.get_value())); - } - hostcalls::send_http_response(429, response_headers, Some(b"Too Many Requests\n")) - .expect("failed to send_http_response 429 while OVER_LIMIT"); - Err(StatusCode::TooManyRequests) - } - GrpcMessageResponse::RateLimit(RateLimitResponse { - overall_code: RateLimitResponse_Code::OK, - response_headers_to_add: additional_headers, - .. - }) => { - let result = GrpcResult::new( - additional_headers - .iter() - .map(|header| (header.get_key().to_owned(), header.get_value().to_owned())) - .collect(), - ); - Ok(result) - } - _ => { - warn!("not a valid GrpcMessageResponse::RateLimit(RateLimitResponse)!"); - GrpcService::handle_error_on_grpc_response(failure_mode); - Err(StatusCode::InternalServerError) - } - } + pub fn request_message_as_bytes( + domain: String, + descriptors: RepeatedField, + ) -> Option> { + Self::request_message(domain, descriptors) + .write_to_bytes() + .map_err(|e| debug!("Failed to write protobuf message to bytes: {e:?}")) + .ok() } } diff --git a/tests/auth.rs b/tests/auth.rs index c11e533b..1a05e77f 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -164,6 +164,10 @@ fn it_auths() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -182,7 +186,7 @@ fn it_auths() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -205,8 +209,13 @@ fn it_auths() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received OkHttpResponse"), + Some("process_response(auth): store_metadata"), ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received OkHttpResponse"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); @@ -347,6 +356,10 @@ fn it_denies() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -365,7 +378,7 @@ fn it_denies() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -388,8 +401,13 @@ fn it_denies() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received DeniedHttpResponse"), + Some("process_response(auth): store_metadata"), ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received DeniedHttpResponse"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Die")) .expect_send_local_response( Some(401), None, @@ -548,6 +566,10 @@ fn it_does_not_fold_auth_actions() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -566,7 +588,7 @@ fn it_does_not_fold_auth_actions() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -589,7 +611,11 @@ fn it_does_not_fold_auth_actions() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received OkHttpResponse"), + Some("process_response(auth): store_metadata"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received OkHttpResponse"), ) .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) @@ -653,6 +679,10 @@ fn it_does_not_fold_auth_actions() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) .expect_grpc_call( Some("authorino-cluster"), Some("envoy.service.auth.v3.Authorization"), @@ -662,6 +692,10 @@ fn it_does_not_fold_auth_actions() { Some(5000), ) .returning(Ok(42)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: AwaitGrpcResponse"), + ) .execute_and_expect(ReturnType::None) .unwrap(); @@ -683,8 +717,13 @@ fn it_does_not_fold_auth_actions() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received OkHttpResponse"), + Some("process_response(auth): store_metadata"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received OkHttpResponse"), ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); diff --git a/tests/failuremode.rs b/tests/failuremode.rs index c05544f2..38daf569 100644 --- a/tests/failuremode.rs +++ b/tests/failuremode.rs @@ -101,6 +101,10 @@ fn it_runs_next_action_on_failure_when_failuremode_is_allow() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -119,7 +123,7 @@ fn it_runs_next_action_on_failure_when_failuremode_is_allow() { .returning(Ok(first_call_token_id)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -132,8 +136,10 @@ fn it_runs_next_action_on_failure_when_failuremode_is_allow() { Some(LogLevel::Debug), Some(format!("#2 on_grpc_call_response: received gRPC call response: token: {first_call_token_id}, status: {status_code}").as_str()), ) - .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) - .returning(Some(&[])) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) .expect_grpc_call( Some("limitador-cluster"), Some("envoy.service.ratelimit.v3.RateLimitService"), @@ -143,6 +149,10 @@ fn it_runs_next_action_on_failure_when_failuremode_is_allow() { Some(5000), ) .returning(Ok(second_call_token_id)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: AwaitGrpcResponse"), + ) .execute_and_expect(ReturnType::None) .unwrap(); @@ -159,6 +169,14 @@ fn it_runs_next_action_on_failure_when_failuremode_is_allow() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: Done"), + ) .execute_and_expect(ReturnType::None) .unwrap(); @@ -264,6 +282,10 @@ fn it_stops_on_failure_when_failuremode_is_deny() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -282,7 +304,7 @@ fn it_stops_on_failure_when_failuremode_is_deny() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -294,8 +316,10 @@ fn it_stops_on_failure_when_failuremode_is_deny() { Some(LogLevel::Debug), Some(format!("#2 on_grpc_call_response: received gRPC call response: token: 42, status: {status_code}").as_str()), ) - .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) - .returning(Some(&[])) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: Die"), + ) .expect_send_local_response(Some(500), None, None, None) .execute_and_expect(ReturnType::None) .unwrap(); diff --git a/tests/failures.rs b/tests/failures.rs index 3e09ea33..42dda487 100644 --- a/tests/failures.rs +++ b/tests/failures.rs @@ -84,6 +84,10 @@ fn it_fails_on_first_action_grpc_call() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -104,13 +108,10 @@ fn it_fails_on_first_action_grpc_call() { ) .returning(Err(TestStatus::ParseFailure)) .expect_log( - Some(LogLevel::Error), - Some("OperationError { status: ParseFailure, failure_mode: Deny }"), - ) - .expect_log( - Some(LogLevel::Warn), - Some("OperationError Status: ParseFailure"), + Some(LogLevel::Debug), + Some("handle_operation: failed to send grpc request `ParseFailure`"), ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Die")) .expect_send_local_response( Some(500), Some("Internal Server Error.\n"), @@ -219,6 +220,10 @@ fn it_fails_on_second_action_grpc_call() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -240,7 +245,7 @@ fn it_fails_on_second_action_grpc_call() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -254,6 +259,14 @@ fn it_fails_on_second_action_grpc_call() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) .expect_grpc_call( Some("does-not-exist"), Some("envoy.service.ratelimit.v3.RateLimitService"), @@ -267,9 +280,10 @@ fn it_fails_on_second_action_grpc_call() { ) .returning(Err(TestStatus::ParseFailure)) .expect_log( - Some(LogLevel::Error), - Some("OperationError { status: ParseFailure, failure_mode: Deny }"), + Some(LogLevel::Debug), + Some("handle_operation: failed to send grpc request `ParseFailure`"), ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Die")) .expect_send_local_response( Some(500), Some("Internal Server Error.\n"), @@ -358,6 +372,10 @@ fn it_fails_on_first_action_grpc_response() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -379,7 +397,7 @@ fn it_fails_on_first_action_grpc_response() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -391,8 +409,7 @@ fn it_fails_on_first_action_grpc_response() { Some(LogLevel::Debug), Some("#2 on_grpc_call_response: received gRPC call response: token: 42, status: 14"), ) - .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) - .returning(Some(&[])) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Die")) .expect_send_local_response( Some(500), Some("Internal Server Error.\n"), @@ -500,6 +517,10 @@ fn it_fails_on_second_action_grpc_response() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -521,7 +542,7 @@ fn it_fails_on_second_action_grpc_response() { .returning(Ok(first_call_token_id)) .expect_log( Some(LogLevel::Debug), - Some(format!("#2 initiated gRPC call (id# {first_call_token_id})").as_str()), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -540,6 +561,14 @@ fn it_fails_on_second_action_grpc_response() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) .expect_grpc_call( Some("unreachable-cluster"), Some("envoy.service.ratelimit.v3.RateLimitService"), @@ -552,6 +581,10 @@ fn it_fails_on_second_action_grpc_response() { Some(5000), ) .returning(Ok(second_call_token_id)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: AwaitGrpcResponse"), + ) .execute_and_expect(ReturnType::None) .unwrap(); @@ -564,8 +597,7 @@ fn it_fails_on_second_action_grpc_response() { "#2 on_grpc_call_response: received gRPC call response: token: {second_call_token_id}, status: {status_code}" ).as_str()), ) - .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) - .returning(Some(&[])) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Die")) .expect_send_local_response( Some(500), Some("Internal Server Error.\n"), diff --git a/tests/multi.rs b/tests/multi.rs index 4f9c2946..7b9857c9 100644 --- a/tests/multi.rs +++ b/tests/multi.rs @@ -181,6 +181,10 @@ fn it_performs_authenticated_rate_limiting() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -199,7 +203,7 @@ fn it_performs_authenticated_rate_limiting() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -215,7 +219,15 @@ fn it_performs_authenticated_rate_limiting() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received OkHttpResponse"), + Some("process_response(auth): store_metadata"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received OkHttpResponse"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), ) .expect_grpc_call( Some("limitador-cluster"), @@ -226,6 +238,10 @@ fn it_performs_authenticated_rate_limiting() { Some(5000), ) .returning(Ok(43)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: AwaitGrpcResponse"), + ) .execute_and_expect(ReturnType::None) .unwrap(); @@ -238,6 +254,11 @@ fn it_performs_authenticated_rate_limiting() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); @@ -378,6 +399,10 @@ fn unauthenticated_does_not_ratelimit() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -396,7 +421,7 @@ fn unauthenticated_does_not_ratelimit() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -419,8 +444,13 @@ fn unauthenticated_does_not_ratelimit() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received DeniedHttpResponse"), + Some("process_response(auth): store_metadata"), ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received DeniedHttpResponse"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Die")) .expect_send_local_response( Some(401), None, @@ -638,6 +668,10 @@ fn authenticated_one_ratelimit_action_matches() { ) .expect_get_property(Some(vec!["source", "port"])) .returning(Some(data::source::port::P_45000)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -656,7 +690,7 @@ fn authenticated_one_ratelimit_action_matches() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -672,7 +706,11 @@ fn authenticated_one_ratelimit_action_matches() { .returning(Some(&grpc_response)) .expect_log( Some(LogLevel::Debug), - Some("process_auth_grpc_response: received OkHttpResponse"), + Some("process_response(auth): store_metadata"), + ) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(auth): received OkHttpResponse"), ) // conditions checks .expect_log( @@ -687,6 +725,10 @@ fn authenticated_one_ratelimit_action_matches() { ) .expect_get_property(Some(vec!["source", "address"])) .returning(Some("1.2.3.4:80".as_bytes())) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) .expect_grpc_call( Some("limitador-cluster"), Some("envoy.service.ratelimit.v3.RateLimitService"), @@ -696,6 +738,10 @@ fn authenticated_one_ratelimit_action_matches() { Some(5000), ) .returning(Ok(43)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: AwaitGrpcResponse"), + ) .execute_and_expect(ReturnType::None) .unwrap(); @@ -708,6 +754,11 @@ fn authenticated_one_ratelimit_action_matches() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 8c897086..78e322d1 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -53,10 +53,6 @@ fn it_loads() { .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) - .expect_log( - Some(LogLevel::Debug), - Some("#2 allowing request to pass because zero descriptors generated"), - ) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -168,6 +164,10 @@ fn it_limits() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -186,7 +186,7 @@ fn it_limits() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -200,6 +200,11 @@ fn it_limits() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); @@ -311,6 +316,10 @@ fn it_passes_additional_headers() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -329,7 +338,7 @@ fn it_passes_additional_headers() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -347,6 +356,12 @@ fn it_passes_additional_headers() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: AddHeaders")) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); @@ -444,6 +459,10 @@ fn it_rate_limits_with_empty_predicates() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -462,7 +481,7 @@ fn it_rate_limits_with_empty_predicates() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -476,6 +495,11 @@ fn it_rate_limits_with_empty_predicates() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); @@ -568,8 +592,9 @@ fn it_does_not_rate_limits_when_predicates_does_not_match() { .returning(Some(data::request::path::ADMIN)) .expect_log( Some(LogLevel::Debug), - Some("grpc_message_request: empty descriptors"), + Some("build_message(rl): empty descriptors"), ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -682,6 +707,10 @@ fn it_folds_subsequent_actions_to_limitador_into_a_single_one() { Some(LogLevel::Debug), Some("#2 action_set selected some-name"), ) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -700,7 +729,7 @@ fn it_folds_subsequent_actions_to_limitador_into_a_single_one() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -714,6 +743,11 @@ fn it_folds_subsequent_actions_to_limitador_into_a_single_one() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); diff --git a/tests/remote_address.rs b/tests/remote_address.rs index 3d25e99a..dc75a149 100644 --- a/tests/remote_address.rs +++ b/tests/remote_address.rs @@ -102,6 +102,10 @@ fn it_limits_based_on_source_address() { ) .expect_get_property(Some(vec!["source", "address"])) .returning(Some(data::source::ADDRESS)) + .expect_log( + Some(LogLevel::Debug), + Some("handle_operation: SendGrpcRequest"), + ) // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) @@ -124,7 +128,7 @@ fn it_limits_based_on_source_address() { .returning(Ok(42)) .expect_log( Some(LogLevel::Debug), - Some("#2 initiated gRPC call (id# 42)"), + Some("handle_operation: AwaitGrpcResponse"), ) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -138,6 +142,11 @@ fn it_limits_based_on_source_address() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .expect_log( + Some(LogLevel::Debug), + Some("process_response(rl): received OK response"), + ) + .expect_log(Some(LogLevel::Debug), Some("handle_operation: Done")) .execute_and_expect(ReturnType::None) .unwrap(); diff --git a/utils/kustomize/limitador/limitador.yaml b/utils/kustomize/limitador/limitador.yaml index 93823123..523b28d4 100644 --- a/utils/kustomize/limitador/limitador.yaml +++ b/utils/kustomize/limitador/limitador.yaml @@ -4,6 +4,7 @@ kind: Limitador metadata: name: limitador spec: + image: quay.io/kuadrant/limitador:v1.6.0 verbosity: 3 listener: http: