diff --git a/src/hyperlight_guest/src/error.rs b/src/hyperlight_guest/src/error.rs index db7e01924..e07407f70 100644 --- a/src/hyperlight_guest/src/error.rs +++ b/src/hyperlight_guest/src/error.rs @@ -51,3 +51,294 @@ impl From for HyperlightGuestError { } } } + +/// Extension trait to add context to `Option` and `Result` types in guest code, +/// converting them to `Result`. +/// +/// This is similar to anyhow::Context. +pub trait GuestErrorContext { + type Ok; + /// Adds context to the error if `self` is `None` or `Err`. + fn context(self, ctx: impl Into) -> Result; + /// Adds context and a specific error code to the error if `self` is `None` or `Err`. + fn context_and_code(self, ec: ErrorCode, ctx: impl Into) -> Result; + /// Lazily adds context to the error if `self` is `None` or `Err`. + /// + /// This is useful if constructing the context message is expensive. + fn with_context>(self, ctx: impl FnOnce() -> S) -> Result; + /// Lazily adds context and a specific error code to the error if `self` is `None` or `Err`. + /// + /// This is useful if constructing the context message is expensive. + fn with_context_and_code>( + self, + ec: ErrorCode, + ctx: impl FnOnce() -> S, + ) -> Result; +} + +impl GuestErrorContext for Option { + type Ok = T; + #[inline] + fn context(self, ctx: impl Into) -> Result { + self.with_context_and_code(ErrorCode::GuestError, || ctx) + } + #[inline] + fn context_and_code(self, ec: ErrorCode, ctx: impl Into) -> Result { + self.with_context_and_code(ec, || ctx) + } + #[inline] + fn with_context>(self, ctx: impl FnOnce() -> S) -> Result { + self.with_context_and_code(ErrorCode::GuestError, ctx) + } + #[inline] + fn with_context_and_code>( + self, + ec: ErrorCode, + ctx: impl FnOnce() -> S, + ) -> Result { + match self { + Some(s) => Ok(s), + None => Err(HyperlightGuestError::new(ec, ctx().into())), + } + } +} + +impl GuestErrorContext for core::result::Result { + type Ok = T; + #[inline] + fn context(self, ctx: impl Into) -> Result { + self.with_context_and_code(ErrorCode::GuestError, || ctx) + } + #[inline] + fn context_and_code(self, ec: ErrorCode, ctx: impl Into) -> Result { + self.with_context_and_code(ec, || ctx) + } + #[inline] + fn with_context>(self, ctx: impl FnOnce() -> S) -> Result { + self.with_context_and_code(ErrorCode::GuestError, ctx) + } + #[inline] + fn with_context_and_code>( + self, + ec: ErrorCode, + ctx: impl FnOnce() -> S, + ) -> Result { + match self { + Ok(s) => Ok(s), + Err(e) => Err(HyperlightGuestError::new( + ec, + format!("{}.\nCaused by: {e:?}", ctx().into()), + )), + } + } +} + +/// Macro to return early with a `Err(HyperlightGuestError)`. +/// Usage: +/// ```ignore +/// bail!(ErrorCode::UnknownError => "An error occurred: {}", details); +/// // or +/// bail!("A guest error occurred: {}", details); // defaults to ErrorCode::GuestError +/// ``` +#[macro_export] +macro_rules! bail { + ($ec:expr => $($msg:tt)*) => { + return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))); + }; + ($($msg:tt)*) => { + $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*); + }; +} + +/// Macro to ensure a condition is true, otherwise returns early with a `Err(HyperlightGuestError)`. +/// Usage: +/// ```ignore +/// ensure!(1 + 1 == 3, ErrorCode::UnknownError => "Maths is broken: {}", details); +/// // or +/// ensure!(1 + 1 == 3, "Maths is broken: {}", details); // defaults to ErrorCode::GuestError +/// // or +/// ensure!(1 + 1 == 3); // defaults to ErrorCode::GuestError with a default message +/// ``` +#[macro_export] +macro_rules! ensure { + ($cond:expr) => { + if !($cond) { + $crate::bail!(::core::concat!("Condition failed: `", ::core::stringify!($cond), "`")); + } + }; + ($cond:expr, $ec:expr => $($msg:tt)*) => { + if !($cond) { + $crate::bail!($ec => ::core::concat!("{}\nCaused by failed condition: `", ::core::stringify!($cond), "`"), ::core::format_args!($($msg)*)); + } + }; + ($cond:expr, $($msg:tt)*) => { + $crate::ensure!($cond, $crate::error::ErrorCode::GuestError => $($msg)*); + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_context_option_some() { + let value: Option = Some(42); + let result = value.context("Should be Some"); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_context_option_none() { + let value: Option = None; + let result = value.context("Should be Some"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!(err.message, "Should be Some"); + } + + #[test] + fn test_context_and_code_option_none() { + let value: Option = None; + let result = value.context_and_code(ErrorCode::MallocFailed, "Should be Some"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::MallocFailed); + assert_eq!(err.message, "Should be Some"); + } + + #[test] + fn test_with_context_option_none() { + let value: Option = None; + let result = value.with_context(|| "Lazy context message"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!(err.message, "Lazy context message"); + } + + #[test] + fn test_with_context_and_code_option_none() { + let value: Option = None; + let result = + value.with_context_and_code(ErrorCode::MallocFailed, || "Lazy context message"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::MallocFailed); + assert_eq!(err.message, "Lazy context message"); + } + + #[test] + fn test_context_result_ok() { + let value: core::result::Result = Ok(42); + let result = value.context("Should be Ok"); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_context_result_err() { + let value: core::result::Result = Err("Some error"); + let result = value.context("Should be Ok"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!(err.message, "Should be Ok.\nCaused by: \"Some error\""); + } + + #[test] + fn test_context_and_code_result_err() { + let value: core::result::Result = Err("Some error"); + let result = value.context_and_code(ErrorCode::MallocFailed, "Should be Ok"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::MallocFailed); + assert_eq!(err.message, "Should be Ok.\nCaused by: \"Some error\""); + } + + #[test] + fn test_with_context_result_err() { + let value: core::result::Result = Err("Some error"); + let result = value.with_context(|| "Lazy context message"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!( + err.message, + "Lazy context message.\nCaused by: \"Some error\"" + ); + } + + #[test] + fn test_with_context_and_code_result_err() { + let value: core::result::Result = Err("Some error"); + let result = + value.with_context_and_code(ErrorCode::MallocFailed, || "Lazy context message"); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::MallocFailed); + assert_eq!( + err.message, + "Lazy context message.\nCaused by: \"Some error\"" + ); + } + + #[test] + fn test_bail_macro() { + let result: Result = (|| { + bail!("A guest error occurred"); + })(); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!(err.message, "A guest error occurred"); + } + + #[test] + fn test_bail_macro_with_error_code() { + let result: Result = (|| { + bail!(ErrorCode::MallocFailed => "Memory allocation failed"); + })(); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::MallocFailed); + assert_eq!(err.message, "Memory allocation failed"); + } + + #[test] + fn test_ensure_macro_pass() { + let result: Result = (|| { + ensure!(1 + 1 == 2, "Math works"); + Ok(42) + })(); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_ensure_macro_fail() { + let result: Result = (|| { + ensure!(1 + 1 == 3, "Math is broken"); + Ok(42) + })(); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!( + err.message, + "Math is broken\nCaused by failed condition: `1 + 1 == 3`" + ); + } + + #[test] + fn test_ensure_macro_fail_no_message() { + let result: Result = (|| { + ensure!(1 + 1 == 3); + Ok(42) + })(); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::GuestError); + assert_eq!(err.message, "Condition failed: `1 + 1 == 3`"); + } + + #[test] + fn test_ensure_macro_fail_with_error_code() { + let result: Result = (|| { + ensure!(1 + 1 == 3, ErrorCode::UnknownError => "Math is broken"); + Ok(42) + })(); + let err = result.unwrap_err(); + assert_eq!(err.kind, ErrorCode::UnknownError); + assert_eq!( + err.message, + "Math is broken\nCaused by failed condition: `1 + 1 == 3`" + ); + } +}