From 317fd46b1fe04f63b169b4b00b44d31ffbfba638 Mon Sep 17 00:00:00 2001 From: csyJoy Date: Wed, 25 Sep 2024 21:52:12 +0800 Subject: [PATCH 1/5] feat(async): Asynchronous quote generation phase 1. Asynchronous tdx sgx quote generation phase via `tokio::spawn_block`. 2. Propagate `async` to function that call or implicitly call asynchronous interface. 3. add asynchronous test. --- Cargo.lock | 198 ++++++++++++++++++++++-- rats-rs/src/cert/create.rs | 27 ++-- rats-rs/src/cert/verify.rs | 27 ++-- rats-rs/src/tee/mod.rs | 26 ++-- rats-rs/src/tee/sgx_dcap/attester.rs | 36 ++++- rats-rs/src/tee/sgx_dcap/mod.rs | 9 +- rats-rs/src/tee/tdx/attester.rs | 37 ++++- rats-rs/src/tee/tdx/mod.rs | 8 +- rats-rs/src/transport/spdm/requester.rs | 12 +- rats-rs/src/transport/spdm/responder.rs | 18 ++- rats-rs/src/transport/tls/client.rs | 43 +++-- rats-rs/src/transport/tls/server.rs | 39 +++-- 12 files changed, 377 insertions(+), 103 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1c19917..4227679 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,6 +127,33 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "aws-lc-rs" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f95446d919226d587817a7d21379e6eb099b97b45110a7f272a444ca5c54070" +dependencies = [ + "aws-lc-sys", + "mirai-annotations", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3ddc4a5b231dd6958b140ff3151b6412b3f4321fab354f399eec8f14b06df62" +dependencies = [ + "bindgen 0.69.4", + "cc", + "cmake", + "dunce", + "fs_extra", + "libc", + "paste", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -207,6 +234,29 @@ dependencies = [ "which", ] +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.5.0", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", + "which", +] + [[package]] name = "bit_field" version = "0.10.2" @@ -259,11 +309,13 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.83" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ + "jobserver", "libc", + "shlex", ] [[package]] @@ -374,6 +426,15 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "cmake" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" +dependencies = [ + "cc", +] + [[package]] name = "codec" version = "0.2.2" @@ -467,6 +528,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctor" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "der" version = "0.7.8" @@ -512,6 +583,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "ecdsa" version = "0.16.9" @@ -620,6 +697,12 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d52a7e408202050813e6f1d9addadcaafef3dca7530c7ddfb005d4081cce6779" +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures" version = "0.3.30" @@ -843,13 +926,22 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin 0.5.2", + "spin", ] [[package]] @@ -945,6 +1037,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mirai-annotations" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" + [[package]] name = "nom" version = "7.1.3" @@ -1043,6 +1141,19 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl-hook" +version = "0.1.0" +dependencies = [ + "anyhow", + "ctor", + "env_logger 0.11.2", + "lazy_static", + "libc", + "log", + "rats-rs", +] + [[package]] name = "openssl-sys" version = "0.9.102" @@ -1091,6 +1202,12 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "peeking_take_while" version = "0.1.2" @@ -1242,6 +1359,8 @@ dependencies = [ "intel-dcap", "intel-tee-quote-verification-rs", "itertools", + "lazy_static", + "libc", "log", "maybe-async", "occlum_dcap", @@ -1257,10 +1376,11 @@ dependencies = [ "sha2", "signature", "spdmlib", - "spin 0.9.8", + "spin", "tdx-attest-rs", "thiserror", "tokio", + "tokio-rustls", "x509-cert", "zeroize", ] @@ -1342,7 +1462,7 @@ dependencies = [ "cc", "getrandom", "libc", - "spin 0.9.8", + "spin", "untrusted", "windows-sys 0.48.0", ] @@ -1393,6 +1513,39 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.23.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -1562,19 +1715,13 @@ dependencies = [ "ring", "serde", "serde_json", - "spin 0.9.8", + "spin", "sys_time", "untrusted", "webpki", "zeroize", ] -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - [[package]] name = "spin" version = "0.9.8" @@ -1703,6 +1850,18 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +[[package]] +name = "tls" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap 4.5.4", + "env_logger 0.11.2", + "log", + "rand", + "rats-rs", +] + [[package]] name = "tls_codec" version = "0.4.1" @@ -1754,6 +1913,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls", + "rustls-pki-types", + "tokio", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/rats-rs/src/cert/create.rs b/rats-rs/src/cert/create.rs index 0eb106c..9efdccf 100644 --- a/rats-rs/src/cert/create.rs +++ b/rats-rs/src/cert/create.rs @@ -70,9 +70,10 @@ impl CertBuilder { self } - pub fn build(&self, private_key_algo: AsymmetricAlgo) -> Result> { + #[maybe_async::maybe_async] + pub async fn build(&self, private_key_algo: AsymmetricAlgo) -> Result> { let key = DefaultCrypto::gen_private_key(private_key_algo)?; - let (cert, evidence) = self.build_with_private_key_inner(&key)?; + let (cert, evidence) = self.build_with_private_key_inner(&key).await?; Ok(CertBundle { private_key: key, @@ -81,11 +82,12 @@ impl CertBuilder { }) } - pub fn build_with_private_key( + #[maybe_async::maybe_async] + pub async fn build_with_private_key( &self, key: &AsymmetricPrivateKey, ) -> Result> { - let (cert, evidence) = self.build_with_private_key_inner(&key)?; + let (cert, evidence) = self.build_with_private_key_inner(&key).await?; Ok(CertBundle { private_key: key.clone(), @@ -94,7 +96,8 @@ impl CertBuilder { }) } - fn build_with_private_key_inner( + #[maybe_async::maybe_async] + async fn build_with_private_key_inner( &self, key: &AsymmetricPrivateKey, ) -> Result<(Certificate, A::Evidence)> { @@ -115,7 +118,7 @@ impl CertBuilder { let claims_buffer_hash = DefaultCrypto::hash(HashAlgo::Sha256, &claims_buffer); /* Generate evidence buffer */ - let evidence = self.attester.get_evidence(&claims_buffer_hash)?; + let evidence = self.attester.get_evidence(&claims_buffer_hash).await?; let evidence_buffer = generate_evidence_buffer_with_tag( evidence.get_dice_cbor_tag(), evidence.get_dice_raw_evidence(), @@ -146,8 +149,10 @@ pub mod tests { #[allow(unused_imports)] use super::*; - #[test] - fn test_get_attestation_certificate() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_get_attestation_certificate() -> Result<()> { if TeeType::detect_env() == None { /* skip */ return Ok(()); @@ -161,7 +166,8 @@ pub mod tests { let attester = AutoAttester::new(); let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) .with_claims(claims.clone()) - .build(AsymmetricAlgo::P256)?; + .build(AsymmetricAlgo::P256) + .await?; println!("generated cert:\n{}", cert_bundle.cert_to_pem()?); println!( @@ -174,7 +180,8 @@ pub mod tests { let key = DefaultCrypto::gen_private_key(AsymmetricAlgo::P256)?; let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) .with_claims(claims) - .build_with_private_key(&key)?; + .build_with_private_key(&key) + .await?; println!("generated cert:\n{}", cert_bundle.cert_to_pem()?); println!( diff --git a/rats-rs/src/cert/verify.rs b/rats-rs/src/cert/verify.rs index e7e045f..e73d723 100644 --- a/rats-rs/src/cert/verify.rs +++ b/rats-rs/src/cert/verify.rs @@ -234,8 +234,10 @@ pub mod tests { #[allow(unused_imports)] use super::*; - #[test] - fn test_verify_cert_der() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_verify_cert_der() -> Result<()> { if TeeType::detect_env() == None { /* skip */ return Ok(()); @@ -249,7 +251,8 @@ pub mod tests { let attester = AutoAttester::new(); let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) .with_claims(claims.clone()) - .build(AsymmetricAlgo::P256)?; + .build(AsymmetricAlgo::P256) + .await?; let cert = cert_bundle.cert_to_der()?; let parsed_claims = verify_cert_der(&cert)?; @@ -263,8 +266,10 @@ pub mod tests { Ok(()) } - #[test] - fn test_verify_attestation_certificate() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_verify_attestation_certificate() -> Result<()> { if TeeType::detect_env() == None { /* skip */ return Ok(()); @@ -278,7 +283,8 @@ pub mod tests { let attester = AutoAttester::new(); let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) .with_claims(claims.clone()) - .build(AsymmetricAlgo::P256)?; + .build(AsymmetricAlgo::P256) + .await?; let cert = cert_bundle.cert_to_der()?; assert_eq!( @@ -303,8 +309,10 @@ pub mod tests { Ok(()) } - #[test] - fn test_verify_attestation_certificate_with_claims_overriding() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_verify_attestation_certificate_with_claims_overriding() -> Result<()> { if TeeType::detect_env() == None { /* skip */ return Ok(()); @@ -320,7 +328,8 @@ pub mod tests { let attester = AutoAttester::new(); let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) .with_claims(claims.clone()) - .build(AsymmetricAlgo::P256)?; + .build(AsymmetricAlgo::P256) + .await?; let cert = cert_bundle.cert_to_der()?; assert_eq!( diff --git a/rats-rs/src/tee/mod.rs b/rats-rs/src/tee/mod.rs index 6060473..28530d6 100644 --- a/rats-rs/src/tee/mod.rs +++ b/rats-rs/src/tee/mod.rs @@ -26,11 +26,12 @@ pub trait GenericEvidence: Any { } /// Trait representing a generic attester. +#[maybe_async::maybe_async] pub trait GenericAttester { type Evidence: GenericEvidence; /// Generate evidence based on the provided report data. - fn get_evidence(&self, report_data: &[u8]) -> Result; + async fn get_evidence(&self, report_data: &[u8]) -> Result; } /// Trait representing a generic verifier. @@ -122,10 +123,11 @@ impl AutoAttester{ } } +#[maybe_async::maybe_async] impl GenericAttester for AutoAttester { type Evidence = AutoEvidence; - fn get_evidence(&self, report_data: &[u8]) -> Result { + async fn get_evidence(&self, report_data: &[u8]) -> Result { let tee_type = TeeType::detect_env(); if let Some(tee_type) = tee_type { @@ -133,12 +135,12 @@ impl GenericAttester for AutoAttester { #[cfg(feature = "attester-sgx-dcap")] TeeType::SgxDcap => { let attester = sgx_dcap::attester::SgxDcapAttester::new(); - attester.get_evidence(report_data).map(|ev|AutoEvidence(Box::new(ev) as Box)) + attester.get_evidence(report_data).await.map(|ev|AutoEvidence(Box::new(ev) as Box)) } #[cfg(feature = "attester-tdx")] TeeType::Tdx => { let attester = tdx::attester::TdxAttester::new(); - attester.get_evidence(report_data).map(|ev|AutoEvidence(Box::new(ev) as Box)) + attester.get_evidence(report_data).await.map(|ev|AutoEvidence(Box::new(ev) as Box)) } #[allow(unreachable_patterns)] _ => { @@ -215,8 +217,10 @@ pub mod tests { use super::*; - #[test] - fn test_auto_attester_and_auto_verifier_on_sgx_dcap() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_auto_attester_and_auto_verifier_on_sgx_dcap() -> Result<()> { if TeeType::detect_env() != Some(TeeType::SgxDcap) { /* skip */ return Ok(()); @@ -224,7 +228,7 @@ pub mod tests { let report_data = b"test_report_data"; let attester = AutoAttester::new(); - let evidence = attester.get_evidence(report_data)?; + let evidence = attester.get_evidence(report_data).await?; assert_eq!(evidence.get_tee_type(), TeeType::SgxDcap); let verifier = AutoVerifier::new(); assert_eq!(verifier.verify_evidence(&evidence, report_data), Ok(())); @@ -239,8 +243,10 @@ pub mod tests { } - #[test] - fn test_auto_attester_and_auto_verifier_on_non_tee() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_auto_attester_and_auto_verifier_on_non_tee() -> Result<()> { if TeeType::detect_env() != None { /* skip */ return Ok(()); @@ -248,7 +254,7 @@ pub mod tests { let report_data = b"test_report_data"; let attester = AutoAttester::new(); - let res = attester.get_evidence(report_data); + let res = attester.get_evidence(report_data).await; assert!(res.is_err()); let Err(err) = res else {panic!()}; assert_eq!(err.get_kind(), ErrorKind::UnsupportedTeeType); diff --git a/rats-rs/src/tee/sgx_dcap/attester.rs b/rats-rs/src/tee/sgx_dcap/attester.rs index 4d09226..5d087d2 100644 --- a/rats-rs/src/tee/sgx_dcap/attester.rs +++ b/rats-rs/src/tee/sgx_dcap/attester.rs @@ -6,6 +6,9 @@ use crate::errors::*; use crate::tee::GenericAttester; use occlum_dcap::{sgx_report_data_t, DcapQuote}; +#[cfg(feature = "async-tokio")] +use tokio::task; + pub struct SgxDcapAttester {} impl SgxDcapAttester { @@ -14,10 +17,11 @@ impl SgxDcapAttester { } } +#[maybe_async::maybe_async] impl GenericAttester for SgxDcapAttester { type Evidence = SgxDcapEvidence; - fn get_evidence(&self, report_data: &[u8]) -> Result { + async fn get_evidence(&self, report_data: &[u8]) -> Result { if cfg!(feature = "attester-sgx-dcap-occlum") { if report_data.len() > 64 { Err(Error::kind_with_msg( @@ -34,13 +38,29 @@ impl GenericAttester for SgxDcapAttester { let mut sgx_report_data = sgx_report_data_t::default(); sgx_report_data.d[..report_data.len()].clone_from_slice(report_data); - handler - .generate_quote( - occlum_quote.as_mut_ptr(), - &sgx_report_data as *const sgx_report_data_t, - ) - .kind(ErrorKind::SgxDcapAttesterGenerateQuoteFailed) - .context("failed at generate_quote()")?; + let ptr = occlum_quote.as_mut_ptr() as usize; + + #[cfg(not(feature = "is-sync"))] + { + task::spawn_blocking(move || { + handler + .generate_quote( + ptr as *mut u8, + &sgx_report_data as *const sgx_report_data_t, + ) + .kind(ErrorKind::SgxDcapAttesterGenerateQuoteFailed) + .context("failed at generate_quote()"); + }) + .await?; + } + + #[cfg(feature = "is-sync")] + { + handler + .generate_quote(ptr as *mut u8, &sgx_report_data as *const sgx_report_data_t) + .kind(ErrorKind::SgxDcapAttesterGenerateQuoteFailed) + .context("failed at generate_quote()"); + } SgxDcapEvidence::new_from_checked(occlum_quote) } else { diff --git a/rats-rs/src/tee/sgx_dcap/mod.rs b/rats-rs/src/tee/sgx_dcap/mod.rs index 7bb0ee0..907b2de 100644 --- a/rats-rs/src/tee/sgx_dcap/mod.rs +++ b/rats-rs/src/tee/sgx_dcap/mod.rs @@ -29,8 +29,10 @@ pub mod tests { verifier::SgxDcapVerifier, }; - #[test] - fn test_attester_and_verifier() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_attester_and_verifier() -> Result<()> { if TeeType::detect_env() != Some(TeeType::SgxDcap) { /* skip */ return Ok(()); @@ -38,7 +40,7 @@ pub mod tests { let report_data = b"test_report_data"; let attester = SgxDcapAttester::new(); - let evidence = attester.get_evidence(report_data)?; + let evidence = attester.get_evidence(report_data).await?; assert_eq!(evidence.get_tee_type(), TeeType::SgxDcap); let verifier = SgxDcapVerifier::new(); assert_eq!(verifier.verify_evidence(&evidence, report_data), Ok(())); @@ -51,7 +53,6 @@ pub mod tests { assert!(claims.contains_key(BUILT_IN_CLAIM_SGX_MR_ENCLAVE)); assert!(claims.contains_key(BUILT_IN_CLAIM_SGX_MR_SIGNER)); - assert_eq!( claims.get(BUILT_IN_CLAIM_COMMON_QUOTE_TYPE), Some(&"sgx_dcap".as_bytes().into()) diff --git a/rats-rs/src/tee/tdx/attester.rs b/rats-rs/src/tee/tdx/attester.rs index eee76a1..0d0c04b 100644 --- a/rats-rs/src/tee/tdx/attester.rs +++ b/rats-rs/src/tee/tdx/attester.rs @@ -2,6 +2,9 @@ use super::evidence::TdxEvidence; use crate::errors::*; use crate::tee::GenericAttester; +#[cfg(feature = "async-tokio")] +use tokio::task; + pub struct TdxAttester {} impl TdxAttester { @@ -10,10 +13,11 @@ impl TdxAttester { } } +#[maybe_async::maybe_async] impl GenericAttester for TdxAttester { type Evidence = TdxEvidence; - fn get_evidence(&self, report_data: &[u8]) -> Result { + async fn get_evidence(&self, report_data: &[u8]) -> Result { if report_data.len() > 64 { Err(Error::kind_with_msg( ErrorKind::InvalidParameter, @@ -25,13 +29,32 @@ impl GenericAttester for TdxAttester { tdx_report_data.d[..report_data.len()].clone_from_slice(report_data); let mut selected_att_key_id = tdx_attest_rs::tdx_uuid_t { d: [0; 16usize] }; + let result; + let quote; + + #[cfg(not(feature = "is-sync"))] + { + (result, quote) = task::spawn_blocking(move || { + tdx_attest_rs::tdx_att_get_quote( + Some(&tdx_report_data), + None, + Some(&mut selected_att_key_id), + 0, + ) + }) + .await + .expect("Failed to execute blocking operation"); + } - let (result, quote) = tdx_attest_rs::tdx_att_get_quote( - Some(&tdx_report_data), - None, - Some(&mut selected_att_key_id), - 0, - ); + #[cfg(feature = "is-sync")] + { + (result, quote) = tdx_attest_rs::tdx_att_get_quote( + Some(&tdx_report_data), + None, + Some(&mut selected_att_key_id), + 0, + ); + } if result != tdx_attest_rs::tdx_attest_error_t::TDX_ATTEST_SUCCESS { Err(Error::kind_with_msg( diff --git a/rats-rs/src/tee/tdx/mod.rs b/rats-rs/src/tee/tdx/mod.rs index f6ba188..397108f 100644 --- a/rats-rs/src/tee/tdx/mod.rs +++ b/rats-rs/src/tee/tdx/mod.rs @@ -37,8 +37,10 @@ pub mod tests { verifier::TdxVerifier, }; - #[test] - fn test_attester_and_verifier() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_attester_and_verifier() -> Result<()> { if TeeType::detect_env() != Some(TeeType::Tdx) { /* skip */ return Ok(()); @@ -46,7 +48,7 @@ pub mod tests { let report_data = b"test_report_data"; let attester = TdxAttester::new(); - let evidence = attester.get_evidence(report_data)?; + let evidence = attester.get_evidence(report_data).await?; assert_eq!(evidence.get_tee_type(), TeeType::Tdx); let verifier = TdxVerifier::new(); assert_eq!(verifier.verify_evidence(&evidence, report_data), Ok(())); diff --git a/rats-rs/src/transport/spdm/requester.rs b/rats-rs/src/transport/spdm/requester.rs index 3760acd..0d9edde 100644 --- a/rats-rs/src/transport/spdm/requester.rs +++ b/rats-rs/src/transport/spdm/requester.rs @@ -74,13 +74,15 @@ impl SpdmRequesterBuilder { self } - pub fn build_with_tcp_stream(&self, stream: TcpStream) -> Result { + #[maybe_async::maybe_async] + pub async fn build_with_tcp_stream(&self, stream: TcpStream) -> Result { let (cert_provider, asym_signer, measurement_provider) = if cfg!(feature = "mut-auth") && self.attest_self { let attester = AutoAttester::new(); - let cert_bundle = - CertBuilder::new(attester, HashAlgo::Sha256).build(AsymmetricAlgo::P256)?; + let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) + .build(AsymmetricAlgo::P256) + .await?; ( Box::new(RatsCertProvider::new_der(cert_bundle.cert_to_der()?)) as Box, @@ -382,7 +384,9 @@ pub mod tests { ) } else { let attester = AutoAttester::new(); - let cert_bundle = CertBuilder::new(attester, hash_algo).build(asym_algo)?; + let cert_bundle = CertBuilder::new(attester, hash_algo) + .build(asym_algo) + .await?; ( Box::new(RatsCertProvider::new_der(cert_bundle.cert_to_der()?)) as Box, diff --git a/rats-rs/src/transport/spdm/responder.rs b/rats-rs/src/transport/spdm/responder.rs index 0fc7505..008c0b7 100644 --- a/rats-rs/src/transport/spdm/responder.rs +++ b/rats-rs/src/transport/spdm/responder.rs @@ -63,12 +63,14 @@ impl SpdmResponderBuilder { self } - pub fn build_with_tcp_stream(&self, stream: TcpStream) -> Result { + #[maybe_async::maybe_async] + pub async fn build_with_tcp_stream(&self, stream: TcpStream) -> Result { let (cert_provider, asym_signer, measurement_provider) = if self.attest_self { // TODO: generate cert and key for each handshake and check nonce from user let attester = AutoAttester::new(); - let cert_bundle = - CertBuilder::new(attester, HashAlgo::Sha256).build(AsymmetricAlgo::P256)?; + let cert_bundle = CertBuilder::new(attester, HashAlgo::Sha256) + .build(AsymmetricAlgo::P256) + .await?; ( Box::new(RatsCertProvider::new_der(cert_bundle.cert_to_der()?)) as Box, @@ -354,7 +356,9 @@ pub mod tests { ) } else { let attester = AutoAttester::new(); - let cert_bundle = CertBuilder::new(attester, hash_algo).build(asym_algo)?; + let cert_bundle = CertBuilder::new(attester, hash_algo) + .build(asym_algo) + .await?; ( Box::new(RatsCertProvider::new_der(cert_bundle.cert_to_der()?)) as Box, @@ -411,8 +415,10 @@ pub mod tests { const THREAD_STACK_SIZE: usize = 8 * 1024 * 1024; - #[test] - fn test_spdm_over_tcp() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async::maybe_async] + async fn test_spdm_over_tcp() -> Result<()> { let _ = env_logger::builder() .is_test(true) .filter_level(LevelFilter::Trace) diff --git a/rats-rs/src/transport/tls/client.rs b/rats-rs/src/transport/tls/client.rs index 0b5c54a..986522a 100644 --- a/rats-rs/src/transport/tls/client.rs +++ b/rats-rs/src/transport/tls/client.rs @@ -37,6 +37,8 @@ pub struct Client { attest_self: bool, } +unsafe impl Send for Client {} + pub struct TlsClientBuilder { verify: SSL_verify_cb, stream: Option>, @@ -44,7 +46,8 @@ pub struct TlsClientBuilder { } impl TlsClientBuilder { - pub fn build(self) -> Result { + #[maybe_async] + pub async fn build(self) -> Result { ossl_init()?; let ctx = unsafe { SSL_CTX_new(TLS_client_method()) }; if ctx.is_null() { @@ -63,7 +66,8 @@ impl TlsClientBuilder { let privkey = DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?; c.use_privkey(&privkey)?; let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) - .build_with_private_key(&privkey)? + .build_with_private_key(&privkey) + .await? .cert_to_der()?; c.use_cert(&cert)?; } @@ -248,6 +252,7 @@ mod tests { GenericSecureTransPortWrite, }, }; + use maybe_async::maybe_async; use openssl_sys::*; use std::{ net::TcpStream, @@ -261,12 +266,15 @@ mod tests { 0 } } - #[test] - fn test_client_shutdown() -> Result<()> { + + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async] + async fn test_client_shutdown() -> Result<()> { let mut builder = TlsClientBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl)); - let mut c = builder.build()?; - c.shutdown()?; + let mut c = builder.build().await?; + c.shutdown().await?; Ok(()) } @@ -297,29 +305,34 @@ mod tests { now } - #[test] - fn test_client_use_key() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async] + async fn test_client_use_key() -> Result<()> { let mut builder = TlsClientBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl)); - let mut c = builder.build()?; + let mut c = builder.build().await?; let privkey = DefaultCrypto::gen_private_key(AsymmetricAlgo::Rsa2048)?; let binding = privkey.to_pkcs8_pem()?; let privpem = binding.as_bytes(); c.use_privkey(&privkey)?; let now = ossl_get_privkey(&mut c); assert_eq!(privpem, now); - c.shutdown()?; + c.shutdown().await?; Ok(()) } - #[test] - fn test_client_use_cert() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async] + async fn test_client_use_cert() -> Result<()> { let mut builder = TlsClientBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl)); - let mut c = builder.build()?; + let mut c = builder.build().await?; let privkey = DefaultCrypto::gen_private_key(AsymmetricAlgo::Rsa2048)?; let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) - .build_with_private_key(&privkey)? + .build_with_private_key(&privkey) + .await? .cert_to_der()?; c.use_cert(&cert)?; let raw_cert = unsafe { SSL_CTX_get0_certificate(c.ctx) }; @@ -327,7 +340,7 @@ mod tests { let len = unsafe { i2d_X509(raw_cert, &mut raw_ptr as *mut *mut u8) }; let now = unsafe { slice::from_raw_parts(raw_ptr as *const u8, len as usize).to_vec() }; assert_eq!(cert, now); - c.shutdown()?; + c.shutdown().await?; Ok(()) } } diff --git a/rats-rs/src/transport/tls/server.rs b/rats-rs/src/transport/tls/server.rs index 8af3ba2..65d6325 100644 --- a/rats-rs/src/transport/tls/server.rs +++ b/rats-rs/src/transport/tls/server.rs @@ -32,6 +32,8 @@ pub struct Server { stream: Box, } +unsafe impl Send for Server {} + // TODO: use typestate design pattern? pub struct TlsServerBuilder { verify: SSL_verify_cb, @@ -40,7 +42,8 @@ pub struct TlsServerBuilder { } impl TlsServerBuilder { - pub fn build(self) -> Result { + #[maybe_async] + pub async fn build(self) -> Result { ossl_init()?; let ctx = unsafe { SSL_CTX_new(TLS_server_method()) }; if ctx.is_null() { @@ -65,7 +68,8 @@ impl TlsServerBuilder { let privkey = DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?; s.use_privkey(&privkey)?; let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) - .build_with_private_key(&privkey)? + .build_with_private_key(&privkey) + .await? .cert_to_der()?; s.use_cert(&cert)?; Ok(s) @@ -254,8 +258,10 @@ mod tests { }, }; use core::slice; + use maybe_async::maybe_async; use openssl_sys::*; use std::{net::TcpStream, ptr}; + struct GetFdDumpImpl; impl GetFd for GetFdDumpImpl { fn get_fd(&self) -> i32 { @@ -263,12 +269,14 @@ mod tests { } } - #[test] - fn test_server_shutdown() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async] + async fn test_server_shutdown() -> Result<()> { let mut builder = TlsServerBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl {})); - let mut s = builder.build()?; - s.shutdown()?; + let mut s = builder.build().await?; + s.shutdown().await?; Ok(()) } @@ -299,8 +307,10 @@ mod tests { now } - #[test] - fn test_server_use_key() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async] + async fn test_server_use_key() -> Result<()> { ossl_init()?; let ctx = unsafe { SSL_CTX_new(TLS_server_method()) }; if ctx.is_null() { @@ -318,12 +328,14 @@ mod tests { s.use_privkey(&privkey)?; let now = ossl_get_privkey(&mut s); assert_eq!(privpem, now.as_slice()); - s.shutdown()?; + s.shutdown().await?; Ok(()) } - #[test] - fn test_server_use_cert() -> Result<()> { + #[cfg_attr(feature = "is-sync", test)] + #[cfg_attr(not(feature = "is-sync"), tokio::test)] + #[maybe_async] + async fn test_server_use_cert() -> Result<()> { ossl_init()?; let ctx = unsafe { SSL_CTX_new(TLS_server_method()) }; if ctx.is_null() { @@ -337,7 +349,8 @@ mod tests { }; let privkey = DefaultCrypto::gen_private_key(AsymmetricAlgo::Rsa2048)?; let bundle = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) - .build_with_private_key(&privkey)?; + .build_with_private_key(&privkey) + .await?; let cert = bundle.cert_to_der()?; println!("cert.pem: {}", bundle.cert_to_pem()?); s.use_cert(&cert)?; @@ -346,7 +359,7 @@ mod tests { let len = unsafe { i2d_X509(raw_cert, &mut raw_ptr as *mut *mut u8) }; let now = unsafe { slice::from_raw_parts(raw_ptr as *const u8, len as usize).to_vec() }; assert_eq!(cert, now); - s.shutdown()?; + s.shutdown().await?; Ok(()) } } From 989c4a1abad75cb626317912fd4cd0b5e2df5fb7 Mon Sep 17 00:00:00 2001 From: csyJoy Date: Wed, 9 Oct 2024 21:16:28 +0800 Subject: [PATCH 2/5] feat(async): async tls support using `tokio-rustls`, crates features refactor 1. add `rustls` module in transport module which use `tokio-rustls` to support async IO by `tokio` 2. refactor current crates' features structure to avoid feature contention --- examples/openssl-hook/Cargo.toml | 2 +- examples/spdm/Cargo.toml | 2 +- examples/tls/Cargo.toml | 2 +- rats-rs/Cargo.toml | 8 +- rats-rs/src/tee/sgx_dcap/attester.rs | 4 +- rats-rs/src/tee/tdx/attester.rs | 4 +- rats-rs/src/transport/mod.rs | 3 + rats-rs/src/transport/rustls/client.rs | 105 ++++++++++++++++++ rats-rs/src/transport/rustls/mod.rs | 147 +++++++++++++++++++++++++ rats-rs/src/transport/rustls/server.rs | 107 ++++++++++++++++++ 10 files changed, 374 insertions(+), 10 deletions(-) create mode 100644 rats-rs/src/transport/rustls/client.rs create mode 100644 rats-rs/src/transport/rustls/mod.rs create mode 100644 rats-rs/src/transport/rustls/server.rs diff --git a/examples/openssl-hook/Cargo.toml b/examples/openssl-hook/Cargo.toml index 1eb3ebf..32114e6 100644 --- a/examples/openssl-hook/Cargo.toml +++ b/examples/openssl-hook/Cargo.toml @@ -15,4 +15,4 @@ env_logger = {workspace = true} libc = "0.2" anyhow = {workspace = true} lazy_static = "1.5" -rats-rs = {path = "../../rats-rs", features = ["is-sync"]} +rats-rs = {path = "../../rats-rs", features = ["transport-tls"]} diff --git a/examples/spdm/Cargo.toml b/examples/spdm/Cargo.toml index 6a49cb9..f542e30 100644 --- a/examples/spdm/Cargo.toml +++ b/examples/spdm/Cargo.toml @@ -7,7 +7,7 @@ edition.workspace = true license.workspace = true [dependencies] -rats-rs = {path = "../../rats-rs", features = ["is-sync"]} +rats-rs = {path = "../../rats-rs", features = ["transport-spdm"]} clap = {version = "4.5.4", features = ["derive"]} env_logger = {workspace = true} rand = {version = "0.8.5", features = ["small_rng"]} diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index c806413..69a4edb 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -6,7 +6,7 @@ readme.workspace = true license.workspace = true [dependencies] -rats-rs = {path = "../../rats-rs", features = ["is-sync"]} +rats-rs = {path = "../../rats-rs", features = ["transport-tls"]} env_logger = {workspace = true} clap = {version = "4.5.4", features = ["derive"]} anyhow = {workspace = true} diff --git a/rats-rs/Cargo.toml b/rats-rs/Cargo.toml index 0d1f3cf..cfb1ed9 100644 --- a/rats-rs/Cargo.toml +++ b/rats-rs/Cargo.toml @@ -41,14 +41,16 @@ intel-dcap = {path = "../intel-dcap", optional = true} openssl-sys = {version = "0.9", optional = true, features = ["bindgen"] } libc = {version = "0.2", optional = true} lazy_static = "1.5.0" +tokio-rustls = { version = "0.26.0", optional = true} [features] async-tokio = ["dep:tokio"] crypto-rustcrypto = ["dep:x509-cert", "dep:sha2", "dep:p256", "dep:rsa", "dep:pkcs8", "dep:const-oid", "dep:signature"] -default = ["crypto-rustcrypto", "transport-spdm", "attester-sgx-dcap-occlum", "verifier-sgx-dcap", "attester-tdx", "verifier-tdx", "is-sync", "transport-tls"] +default = ["crypto-rustcrypto", "attester-sgx-dcap-occlum", "verifier-sgx-dcap", "attester-tdx", "verifier-tdx"] is-sync = ["maybe-async/is_sync", "spdmlib/is_sync"] -transport-spdm = ["dep:spdmlib", "dep:codec", "dep:ring"] -transport-tls = ["dep:openssl-sys", "dep:libc"] +transport-spdm = ["dep:spdmlib", "dep:codec", "dep:ring", "is-sync"] +transport-tls = ["dep:openssl-sys", "dep:libc", "is-sync"] +transport-rustls = ["dep:tokio-rustls", "async-tokio"] attester-sgx-dcap = ["dep:intel-dcap"] attester-sgx-dcap-occlum = ["attester-sgx-dcap", "dep:occlum_dcap"] attester-sgx-dcap-enclave = [] # TODO: plain enclave mode diff --git a/rats-rs/src/tee/sgx_dcap/attester.rs b/rats-rs/src/tee/sgx_dcap/attester.rs index 5d087d2..8a21556 100644 --- a/rats-rs/src/tee/sgx_dcap/attester.rs +++ b/rats-rs/src/tee/sgx_dcap/attester.rs @@ -40,7 +40,7 @@ impl GenericAttester for SgxDcapAttester { let ptr = occlum_quote.as_mut_ptr() as usize; - #[cfg(not(feature = "is-sync"))] + #[cfg(feature = "async-tokio")] { task::spawn_blocking(move || { handler @@ -54,7 +54,7 @@ impl GenericAttester for SgxDcapAttester { .await?; } - #[cfg(feature = "is-sync")] + #[cfg(not(feature = "async-tokio"))] { handler .generate_quote(ptr as *mut u8, &sgx_report_data as *const sgx_report_data_t) diff --git a/rats-rs/src/tee/tdx/attester.rs b/rats-rs/src/tee/tdx/attester.rs index 0d0c04b..f9494e0 100644 --- a/rats-rs/src/tee/tdx/attester.rs +++ b/rats-rs/src/tee/tdx/attester.rs @@ -32,7 +32,7 @@ impl GenericAttester for TdxAttester { let result; let quote; - #[cfg(not(feature = "is-sync"))] + #[cfg(feature = "async-tokio")] { (result, quote) = task::spawn_blocking(move || { tdx_attest_rs::tdx_att_get_quote( @@ -46,7 +46,7 @@ impl GenericAttester for TdxAttester { .expect("Failed to execute blocking operation"); } - #[cfg(feature = "is-sync")] + #[cfg(not(feature = "async-tokio"))] { (result, quote) = tdx_attest_rs::tdx_att_get_quote( Some(&tdx_report_data), diff --git a/rats-rs/src/transport/mod.rs b/rats-rs/src/transport/mod.rs index 0b0d69f..2e94f9e 100644 --- a/rats-rs/src/transport/mod.rs +++ b/rats-rs/src/transport/mod.rs @@ -4,6 +4,9 @@ pub mod spdm; #[cfg(feature = "transport-tls")] pub mod tls; +#[cfg(feature = "transport-rustls")] +pub mod rustls; + use crate::errors::*; use maybe_async::maybe_async; diff --git a/rats-rs/src/transport/rustls/client.rs b/rats-rs/src/transport/rustls/client.rs new file mode 100644 index 0000000..58d44d4 --- /dev/null +++ b/rats-rs/src/transport/rustls/client.rs @@ -0,0 +1,105 @@ +use super::RatsServerVerifier; +use crate::cert::create::CertBuilder; +use crate::crypto::{DefaultCrypto, HashAlgo}; +use crate::errors::Result; +use crate::tee::AutoAttester; +use crate::transport::{ + GenericSecureTransPort, GenericSecureTransPortRead, GenericSecureTransPortWrite, +}; +use maybe_async::maybe_async; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::sync::Arc; +use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::client::WebPkiServerVerifier; +use tokio_rustls::rustls::pki_types::PrivatePkcs8KeyDer; +use tokio_rustls::rustls::server::WebPkiClientVerifier; +use tokio_rustls::rustls::{self, pki_types, ClientConfig}; +use tokio_rustls::TlsConnector; + +#[allow(unused)] +pub struct RustlsClient { + connector: TlsConnector, + addr: String, + reader: Option>>, + writer: Option>>, +} + +impl RustlsClient { + #[maybe_async] + pub async fn new(addr: &str, mutal: bool) -> Result { + let config_builder = rustls::ClientConfig::builder() + .with_root_certificates(Arc::new(rustls::RootCertStore::empty())); + let mut config; + if mutal { + let privkey = DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?; + let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) + .build_with_private_key(&privkey) + .await? + .cert_to_der()?; + let tmp: PrivatePkcs8KeyDer = privkey.to_pkcs8_der()?.as_bytes().to_vec().into(); + config = config_builder.with_client_auth_cert(vec![cert.into()], tmp.into())?; + } else { + config = config_builder.with_no_client_auth(); + } + + config + .dangerous() + .set_certificate_verifier(Arc::new(RatsServerVerifier { + default_server_verifier: WebPkiServerVerifier::builder(Arc::new({ + //XXX: only to bypass empty test of WebPkiServerVerifier + let mut root = rustls::RootCertStore::empty(); + let privkey = + DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?; + let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) + .build_with_private_key(&privkey) + .await? + .cert_to_der()?; + root.add(cert.into())?; + root + })) + .build()?, + })); + Ok(RustlsClient { + connector: TlsConnector::from(Arc::new(config)), + addr: addr.to_string(), + reader: None, + writer: None, + }) + } +} + +#[maybe_async] +impl GenericSecureTransPort for RustlsClient { + async fn negotiate(&mut self) -> Result<()> { + let addr = self.addr.parse::()?.ip(); + let stream = TcpStream::connect(&self.addr).await?; + let domain = pki_types::ServerName::try_from(addr)?; + let tls_stream = self.connector.connect(domain, stream).await?; + let (reader, writer) = split(tls_stream); + self.reader = Some(reader); + self.writer = Some(writer); + Ok(()) + } +} + +#[maybe_async] +impl GenericSecureTransPortWrite for RustlsClient { + async fn send(&mut self, bytes: &[u8]) -> Result<()> { + self.writer.as_mut().unwrap().write(bytes).await?; + Ok(()) + } + async fn shutdown(&mut self) -> Result<()> { + self.writer.as_mut().unwrap().shutdown().await?; + Ok(()) + } +} + +#[maybe_async] +impl GenericSecureTransPortRead for RustlsClient { + async fn receive(&mut self, buf: &mut [u8]) -> Result { + let len = self.reader.as_mut().unwrap().read(buf).await?; + Ok(len) + } +} diff --git a/rats-rs/src/transport/rustls/mod.rs b/rats-rs/src/transport/rustls/mod.rs new file mode 100644 index 0000000..87f8f3f --- /dev/null +++ b/rats-rs/src/transport/rustls/mod.rs @@ -0,0 +1,147 @@ +use crate::cert::verify::CertVerifier; +use crate::cert::verify::VerifiyPolicy::Contains; +use crate::cert::verify::VerifyPolicyOutput; +use crate::tee::claims::Claims; +use std::sync::Arc; +use tokio_rustls::rustls::client::danger::HandshakeSignatureValid; +use tokio_rustls::rustls::client::danger::ServerCertVerified; +use tokio_rustls::rustls::server::danger::ClientCertVerified; +use tokio_rustls::rustls::server::ParsedCertificate; +use tokio_rustls::rustls::CertificateError; +use tokio_rustls::rustls::Error; +use tokio_rustls::rustls::{ + client::{danger::ServerCertVerifier, WebPkiServerVerifier}, + server::{danger::ClientCertVerifier, WebPkiClientVerifier}, +}; + +pub mod client; +pub mod server; + +pub use client::RustlsClient; +pub use server::RustlsServer; + +#[derive(Debug)] +struct RatsClientVerifier { + default_client_verifier: Arc, +} + +#[derive(Debug)] +struct RatsServerVerifier { + default_server_verifier: Arc, +} + +impl ClientCertVerifier for RatsClientVerifier { + fn root_hint_subjects(&self) -> &[tokio_rustls::rustls::DistinguishedName] { + self.default_client_verifier.root_hint_subjects() + } + + fn verify_client_cert( + &self, + end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, + _intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>], + _now: tokio_rustls::rustls::pki_types::UnixTime, + ) -> Result { + let res = CertVerifier::new(Contains(Claims::new())).verify(&end_entity); + match res { + Ok(VerifyPolicyOutput::Passed) => { + return Ok(ClientCertVerified::assertion()); + } + Ok(VerifyPolicyOutput::Failed) => { + return Err(Error::General( + "Verify failed because of claims".to_string(), + )); + } + Err(err) => { + return Err(Error::General( + format!("Verify failed with err: {:?}", err).to_string(), + )); + } + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, + dss: &tokio_rustls::rustls::DigitallySignedStruct, + ) -> Result { + self.default_client_verifier + .verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, + dss: &tokio_rustls::rustls::DigitallySignedStruct, + ) -> Result { + self.default_client_verifier + .verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.default_client_verifier.supported_verify_schemes() + } +} + +impl ServerCertVerifier for RatsServerVerifier { + fn verify_server_cert( + &self, + end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, + _intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>], + _server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: tokio_rustls::rustls::pki_types::UnixTime, + ) -> Result + { + let res = CertVerifier::new(Contains(Claims::new())).verify(&end_entity); + match res { + Ok(VerifyPolicyOutput::Passed) => { + return Ok(ServerCertVerified::assertion()); + } + Ok(VerifyPolicyOutput::Failed) => { + return Err(Error::General( + "Verify failed because of claims".to_string(), + )); + } + Err(err) => { + return Err(Error::General( + format!("Verify failed with err: {:?}", err).to_string(), + )); + } + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, + dss: &tokio_rustls::rustls::DigitallySignedStruct, + ) -> Result< + tokio_rustls::rustls::client::danger::HandshakeSignatureValid, + tokio_rustls::rustls::Error, + > { + self.default_server_verifier + .verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, + dss: &tokio_rustls::rustls::DigitallySignedStruct, + ) -> Result< + tokio_rustls::rustls::client::danger::HandshakeSignatureValid, + tokio_rustls::rustls::Error, + > { + self.default_server_verifier + .verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.default_server_verifier.supported_verify_schemes() + } +} + +#[cfg(test)] +mod test {} diff --git a/rats-rs/src/transport/rustls/server.rs b/rats-rs/src/transport/rustls/server.rs new file mode 100644 index 0000000..df18708 --- /dev/null +++ b/rats-rs/src/transport/rustls/server.rs @@ -0,0 +1,107 @@ +use super::RatsClientVerifier; +use crate::cert::create::CertBuilder; +use crate::crypto::{DefaultCrypto, HashAlgo}; +use crate::errors::Result; +use crate::tee::AutoAttester; +use crate::transport::{ + GenericSecureTransPort, GenericSecureTransPortRead, GenericSecureTransPortWrite, +}; +use maybe_async::maybe_async; +use std::mem; +use std::sync::Arc; +use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::rustls::crypto::CryptoProvider; +use tokio_rustls::rustls::pki_types::PrivatePkcs8KeyDer; +use tokio_rustls::rustls::server::WebPkiClientVerifier; +use tokio_rustls::rustls::{self, ServerConfig}; +use tokio_rustls::server::TlsStream; +use tokio_rustls::TlsAcceptor; + +struct NegotiateInner { + reader: ReadHalf>, + writer: WriteHalf>, +} + +pub struct RustlsServer { + acceptor: TlsAcceptor, + stream: Option, + inner: Option, +} + +impl RustlsServer { + #[maybe_async] + pub async fn new(stream: TcpStream, mutal: bool) -> Result { + let privkey = DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?; + let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) + .build_with_private_key(&privkey) + .await? + .cert_to_der()?; + let tmp: PrivatePkcs8KeyDer = privkey.to_pkcs8_der()?.as_bytes().to_vec().into(); + let config_builder = rustls::ServerConfig::builder(); + let config; + if mutal { + config = config_builder + .with_client_cert_verifier(Arc::new(RatsClientVerifier { + default_client_verifier: WebPkiClientVerifier::builder(Arc::new({ + //XXX: only to bypass empty test of WebPkiClientVerifier + let mut root = rustls::RootCertStore::empty(); + let privkey = + DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?; + let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256) + .build_with_private_key(&privkey) + .await? + .cert_to_der()?; + root.add(cert.into())?; + root + })) + .build()?, + })) + .with_single_cert(vec![cert.into()], tmp.into())?; + } else { + config = config_builder + .with_no_client_auth() + .with_single_cert(vec![cert.into()], tmp.into())?; + } + Ok(RustlsServer { + acceptor: TlsAcceptor::from(Arc::new(config)), + stream: Some(stream), + inner: None, + }) + } +} + +#[maybe_async] +impl GenericSecureTransPort for RustlsServer { + async fn negotiate(&mut self) -> Result<()> { + let acceptor = self.acceptor.clone(); + let stream = std::mem::replace(&mut self.stream, None).unwrap(); + let tls_stream = acceptor.accept(stream).await?; + let (reader, writer) = split(tls_stream); + self.inner = Some(NegotiateInner { + reader: reader, + writer: writer, + }); + Ok(()) + } +} + +#[maybe_async] +impl GenericSecureTransPortWrite for RustlsServer { + async fn send(&mut self, bytes: &[u8]) -> Result<()> { + self.inner.as_mut().unwrap().writer.write(bytes).await?; + Ok(()) + } + async fn shutdown(&mut self) -> Result<()> { + self.inner.as_mut().unwrap().writer.shutdown().await?; + Ok(()) + } +} + +#[maybe_async] +impl GenericSecureTransPortRead for RustlsServer { + async fn receive(&mut self, buf: &mut [u8]) -> Result { + let len = self.inner.as_mut().unwrap().reader.read(buf).await?; + Ok(len) + } +} From 4da759357dc4f6ba425e2a587f498fac4b2bb6b9 Mon Sep 17 00:00:00 2001 From: csyJoy Date: Wed, 9 Oct 2024 21:20:31 +0800 Subject: [PATCH 3/5] example(async): add async tls echo server --- examples/rustls/Cargo.toml | 15 ++++++ examples/rustls/src/echo/mod.rs | 50 ++++++++++++++++++ examples/rustls/src/main.rs | 89 +++++++++++++++++++++++++++++++++ examples/rustls/src/rustls.rs | 67 +++++++++++++++++++++++++ 4 files changed, 221 insertions(+) create mode 100644 examples/rustls/Cargo.toml create mode 100644 examples/rustls/src/echo/mod.rs create mode 100644 examples/rustls/src/main.rs create mode 100644 examples/rustls/src/rustls.rs diff --git a/examples/rustls/Cargo.toml b/examples/rustls/Cargo.toml new file mode 100644 index 0000000..c5e5cc7 --- /dev/null +++ b/examples/rustls/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "rustls" +edition.workspace = true +version.workspace = true +readme.workspace = true +license.workspace = true + +[dependencies] +rats-rs = {path = "../../rats-rs", features = ["transport-rustls"]} +env_logger = {workspace = true} +clap = {version = "4.5.4", features = ["derive"]} +anyhow = {workspace = true} +log = {workspace = true} +rand = {version = "0.8.5", features = ["small_rng"]} +tokio = {version = "1.36.0", features = ["full"]} \ No newline at end of file diff --git a/examples/rustls/src/echo/mod.rs b/examples/rustls/src/echo/mod.rs new file mode 100644 index 0000000..c7bac8f --- /dev/null +++ b/examples/rustls/src/echo/mod.rs @@ -0,0 +1,50 @@ +use crate::{ + rustls::{with_tls_tcp_client, with_tls_tcp_server}, + CommonClientOptions, CommonServerOptions, +}; +use anyhow::Result; +use log::info; +use rand::{rngs::SmallRng, RngCore, SeedableRng}; +use rats_rs::transport::{ + rustls::{RustlsClient, RustlsServer}, + GenericSecureTransPortRead, GenericSecureTransPortWrite, +}; + +pub async fn echo_client(opts: CommonClientOptions) -> Result<()> { + with_tls_tcp_client(opts, async |mut c: RustlsClient| { + let mut rng = SmallRng::from_entropy(); + for _i in 0..128 { + let mut expected = [0u8; 8]; + let expected_len = expected.len(); + + rng.fill_bytes(&mut expected); + c.send(&expected).await?; + + let mut buffer = [0u8; 1024]; + let recv_len = c.receive(&mut buffer[..expected_len]).await?; + + assert_eq!(expected_len, recv_len); + assert_eq!(expected, buffer[..expected_len]); + info!("{}/128: passed", _i + 1); + } + c.shutdown().await?; + Ok(()) + }) + .await?; + Ok(()) +} + +pub async fn echo_server(opts: CommonServerOptions) -> Result<()> { + with_tls_tcp_server(opts, async |mut s: RustlsServer| { + let mut buffer = [0u8; 1024]; + let buffer_len = buffer.len(); + for _i in 0..128 { + let recv_len = s.receive(&mut buffer[..buffer_len]).await?; + s.send(&mut buffer[..recv_len]).await?; + } + s.shutdown().await?; + Ok(()) + }) + .await?; + Ok(()) +} diff --git a/examples/rustls/src/main.rs b/examples/rustls/src/main.rs new file mode 100644 index 0000000..0601fbc --- /dev/null +++ b/examples/rustls/src/main.rs @@ -0,0 +1,89 @@ +#![feature(async_closure)] + +mod echo; +mod rustls; + +use anyhow::Result; +use clap::{arg, ArgAction, Parser}; +use echo::{echo_client, echo_server}; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +enum TlsCommand { + #[command(name = "echo-client")] + EchoClient(CommonClientOptions), + #[command(name = "echo-server")] + EchoServer(CommonServerOptions), +} + +#[derive(Parser, Debug)] +struct CommonServerOptions { + /// Whether to attest self to peer. Defaults to `true`. + #[arg( + long, + default_missing_value("true"), + default_value("true"), + num_args(0..=1), + require_equals(true), + action = ArgAction::Set, + )] + attest_self: bool, + + /// Whether to verify peer. Defaults to `false`. + #[arg( + long, + default_missing_value("true"), + default_value("false"), + num_args(0..=1), + require_equals(true), + action = ArgAction::Set, + )] + verify_peer: bool, + + /// The ip:port to listen on for TCP connections. + #[arg(long)] + listen_on_tcp: String, +} + +#[derive(Parser, Debug, Clone)] +struct CommonClientOptions { + /// Whether to attest self to peer. Defaults to `false`. + #[arg( + long, + default_missing_value("true"), + default_value("false"), + num_args(0..=1), + require_equals(true), + action = ArgAction::Set, + )] + attest_self: bool, + + /// Whether to verify peer. Defaults to `true`. + #[arg( + long, + default_missing_value("true"), + default_value("true"), + num_args(0..=1), + require_equals(true), + action = ArgAction::Set, + )] + verify_peer: bool, + + /// The ip:port to connect to for TCP connection. + #[arg(long)] + connect_to_tcp: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + let env = env_logger::Env::default() + .filter_or("RATS_RS_LOG_LEVEL", "debug") + .write_style_or("RATS_RS_LOG_STYLE", "always"); + env_logger::Builder::from_env(env).init(); + let command = TlsCommand::parse(); + match command { + TlsCommand::EchoClient(opts) => echo_client(opts).await?, + TlsCommand::EchoServer(opts) => echo_server(opts).await?, + } + Ok(()) +} diff --git a/examples/rustls/src/rustls.rs b/examples/rustls/src/rustls.rs new file mode 100644 index 0000000..519b96f --- /dev/null +++ b/examples/rustls/src/rustls.rs @@ -0,0 +1,67 @@ +use std::future::Future; + +use crate::{CommonClientOptions, CommonServerOptions}; +use anyhow::{bail, Result}; +use log::info; +use rats_rs::transport::rustls::{RustlsClient, RustlsServer}; +use rats_rs::transport::{ + GenericSecureTransPort, GenericSecureTransPortRead, GenericSecureTransPortWrite, +}; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; + +pub async fn with_tls_tcp_client( + opts: CommonClientOptions, + func: impl FnOnce(RustlsClient) -> T, +) -> Result<()> +where + T: Future>, +{ + let mut rustls_client = RustlsClient::new(&opts.connect_to_tcp, opts.attest_self).await?; + + info!("Connected to server: {}", opts.connect_to_tcp); + + rustls_client.negotiate().await?; + + info!( + "The tls session on connection {} (responder) is ready.", + opts.connect_to_tcp, + ); + + func(rustls_client).await?; + + info!("Everything is fine, exit now."); + Ok(()) +} + +pub async fn with_tls_tcp_server( + opts: CommonServerOptions, + func: impl Fn(RustlsServer) -> T, +) -> Result<()> +where + T: Future>, +{ + let listener = TcpListener::bind(&opts.listen_on_tcp) + .await + .expect("Failed to bind to address"); + + info!("Server started, listening on {}", &opts.listen_on_tcp); + + loop { + let (stream, peer_addr) = listener.accept().await?; + info!("New connection: {}", peer_addr); + let mut server = RustlsServer::new(stream, opts.verify_peer).await?; + server.negotiate().await?; + + info!( + "The tls session on connection {} (requester) is ready.", + peer_addr + ); + + func(server).await?; + + info!( + "The connection {} is shutdown, waiting for another now.", + peer_addr + ); + } +} From 9f463bfb846ca346bd25e89318d2b604bb6e2396 Mon Sep 17 00:00:00 2001 From: csyJoy Date: Thu, 10 Oct 2024 08:53:24 +0800 Subject: [PATCH 4/5] fix(test): resolve build failed error during cargo test 1. adjust the position of `#[maybe_async]` for test cases. 2. remove `example/*` from default-members in `Cargo.toml` which will cause default `cargo build` fail. 3. use `-p` instead of `--bin` while building an example. 4. adding `base` feature in `rats-rs`, making it convenient to re-add the basic features after disabling the default features. --- .github/workflows/build-and-test.yaml | 2 +- Cargo.toml | 1 - README.md | 2 +- examples/rustls/Cargo.toml | 2 +- rats-rs/Cargo.toml | 3 ++- rats-rs/src/cert/create.rs | 3 ++- rats-rs/src/cert/verify.rs | 15 +++++++-------- rats-rs/src/tee/mod.rs | 7 +++---- rats-rs/src/tee/sgx_dcap/mod.rs | 3 ++- rats-rs/src/tee/tdx/mod.rs | 3 ++- rats-rs/src/transport/spdm/responder.rs | 2 +- rats-rs/src/transport/tls/client.rs | 8 ++++---- rats-rs/src/transport/tls/server.rs | 6 +++--- 13 files changed, 29 insertions(+), 28 deletions(-) diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml index 7876dab..9f72b6a 100644 --- a/.github/workflows/build-and-test.yaml +++ b/.github/workflows/build-and-test.yaml @@ -27,7 +27,7 @@ jobs: - name: Compile ${{ github.repository }} run: - cargo build && cargo build --bin spdm + cargo build && cargo build -p spdm env: HOME: /root diff --git a/Cargo.toml b/Cargo.toml index 1aeab04..f736e3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ ] default-members = [ "rats-rs", - "examples/*", ] resolver = "2" diff --git a/README.md b/README.md index cf71a01..5e9b45d 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ rats-rs是一个纯Rust实现的TEE远程证明库,它的最终目标是让开 just prepare-repo - cargo build --bin spdm + cargo build -p spdm ``` 3. 运行Server端程序 diff --git a/examples/rustls/Cargo.toml b/examples/rustls/Cargo.toml index c5e5cc7..5c47f8c 100644 --- a/examples/rustls/Cargo.toml +++ b/examples/rustls/Cargo.toml @@ -6,7 +6,7 @@ readme.workspace = true license.workspace = true [dependencies] -rats-rs = {path = "../../rats-rs", features = ["transport-rustls"]} +rats-rs = {path = "../../rats-rs", default-features = false, features = ["base", "transport-rustls"]} env_logger = {workspace = true} clap = {version = "4.5.4", features = ["derive"]} anyhow = {workspace = true} diff --git a/rats-rs/Cargo.toml b/rats-rs/Cargo.toml index cfb1ed9..c022911 100644 --- a/rats-rs/Cargo.toml +++ b/rats-rs/Cargo.toml @@ -46,7 +46,8 @@ tokio-rustls = { version = "0.26.0", optional = true} [features] async-tokio = ["dep:tokio"] crypto-rustcrypto = ["dep:x509-cert", "dep:sha2", "dep:p256", "dep:rsa", "dep:pkcs8", "dep:const-oid", "dep:signature"] -default = ["crypto-rustcrypto", "attester-sgx-dcap-occlum", "verifier-sgx-dcap", "attester-tdx", "verifier-tdx"] +base = ["crypto-rustcrypto", "attester-sgx-dcap-occlum", "verifier-sgx-dcap", "attester-tdx", "verifier-tdx"] +default = ["base", "is-sync"] is-sync = ["maybe-async/is_sync", "spdmlib/is_sync"] transport-spdm = ["dep:spdmlib", "dep:codec", "dep:ring", "is-sync"] transport-tls = ["dep:openssl-sys", "dep:libc", "is-sync"] diff --git a/rats-rs/src/cert/create.rs b/rats-rs/src/cert/create.rs index 9efdccf..500dbde 100644 --- a/rats-rs/src/cert/create.rs +++ b/rats-rs/src/cert/create.rs @@ -145,13 +145,14 @@ pub mod tests { errors::*, tee::{claims::Claims, AutoAttester, TeeType}, }; + use maybe_async::maybe_async; #[allow(unused_imports)] use super::*; + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_get_attestation_certificate() -> Result<()> { if TeeType::detect_env() == None { /* skip */ diff --git a/rats-rs/src/cert/verify.rs b/rats-rs/src/cert/verify.rs index e73d723..b69c844 100644 --- a/rats-rs/src/cert/verify.rs +++ b/rats-rs/src/cert/verify.rs @@ -219,8 +219,8 @@ fn extract_ext_with_oid<'a>(cert: &'a Certificate, oid: &ObjectIdentifier) -> Op #[cfg(test)] pub mod tests { - use indexmap::IndexMap; - + #[allow(unused_imports)] + use super::*; use crate::{ cert::create::CertBuilder, crypto::{AsymmetricAlgo, DefaultCrypto, HashAlgo}, @@ -230,13 +230,12 @@ pub mod tests { AutoAttester, TeeType, }, }; + use indexmap::IndexMap; + use maybe_async::maybe_async; - #[allow(unused_imports)] - use super::*; - + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_verify_cert_der() -> Result<()> { if TeeType::detect_env() == None { /* skip */ @@ -266,9 +265,9 @@ pub mod tests { Ok(()) } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_verify_attestation_certificate() -> Result<()> { if TeeType::detect_env() == None { /* skip */ @@ -309,9 +308,9 @@ pub mod tests { Ok(()) } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_verify_attestation_certificate_with_claims_overriding() -> Result<()> { if TeeType::detect_env() == None { /* skip */ diff --git a/rats-rs/src/tee/mod.rs b/rats-rs/src/tee/mod.rs index 28530d6..c170059 100644 --- a/rats-rs/src/tee/mod.rs +++ b/rats-rs/src/tee/mod.rs @@ -212,14 +212,13 @@ impl GenericVerifier for AutoVerifier { #[cfg(test)] pub mod tests { use tests::claims::{BUILT_IN_CLAIM_COMMON_QUOTE, BUILT_IN_CLAIM_COMMON_QUOTE_TYPE}; - use crate::errors::*; - use super::*; + use maybe_async::maybe_async; + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_auto_attester_and_auto_verifier_on_sgx_dcap() -> Result<()> { if TeeType::detect_env() != Some(TeeType::SgxDcap) { /* skip */ @@ -243,9 +242,9 @@ pub mod tests { } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_auto_attester_and_auto_verifier_on_non_tee() -> Result<()> { if TeeType::detect_env() != None { /* skip */ diff --git a/rats-rs/src/tee/sgx_dcap/mod.rs b/rats-rs/src/tee/sgx_dcap/mod.rs index 907b2de..9b282cf 100644 --- a/rats-rs/src/tee/sgx_dcap/mod.rs +++ b/rats-rs/src/tee/sgx_dcap/mod.rs @@ -23,15 +23,16 @@ pub mod tests { GenericAttester, GenericEvidence, GenericVerifier, TeeType, }, }; + use maybe_async::maybe_async; use tests::{ attester::SgxDcapAttester, claims::{BUILT_IN_CLAIM_SGX_MR_ENCLAVE, BUILT_IN_CLAIM_SGX_MR_SIGNER}, verifier::SgxDcapVerifier, }; + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_attester_and_verifier() -> Result<()> { if TeeType::detect_env() != Some(TeeType::SgxDcap) { /* skip */ diff --git a/rats-rs/src/tee/tdx/mod.rs b/rats-rs/src/tee/tdx/mod.rs index 397108f..28be931 100644 --- a/rats-rs/src/tee/tdx/mod.rs +++ b/rats-rs/src/tee/tdx/mod.rs @@ -28,6 +28,7 @@ pub mod tests { GenericAttester, GenericEvidence, GenericVerifier, TeeType, }, }; + use maybe_async::maybe_async; use tests::{ attester::TdxAttester, claims::{ @@ -37,9 +38,9 @@ pub mod tests { verifier::TdxVerifier, }; + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_attester_and_verifier() -> Result<()> { if TeeType::detect_env() != Some(TeeType::Tdx) { /* skip */ diff --git a/rats-rs/src/transport/spdm/responder.rs b/rats-rs/src/transport/spdm/responder.rs index 008c0b7..34a1cfc 100644 --- a/rats-rs/src/transport/spdm/responder.rs +++ b/rats-rs/src/transport/spdm/responder.rs @@ -415,9 +415,9 @@ pub mod tests { const THREAD_STACK_SIZE: usize = 8 * 1024 * 1024; + #[maybe_async::maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async::maybe_async] async fn test_spdm_over_tcp() -> Result<()> { let _ = env_logger::builder() .is_test(true) diff --git a/rats-rs/src/transport/tls/client.rs b/rats-rs/src/transport/tls/client.rs index 986522a..388e307 100644 --- a/rats-rs/src/transport/tls/client.rs +++ b/rats-rs/src/transport/tls/client.rs @@ -240,7 +240,7 @@ impl Client { } #[cfg(test)] -mod tests { +pub mod tests { use super::{Client, TlsClientBuilder}; use crate::{ cert::create::CertBuilder, @@ -267,9 +267,9 @@ mod tests { } } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async] async fn test_client_shutdown() -> Result<()> { let mut builder = TlsClientBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl)); @@ -305,9 +305,9 @@ mod tests { now } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async] async fn test_client_use_key() -> Result<()> { let mut builder = TlsClientBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl)); @@ -322,9 +322,9 @@ mod tests { Ok(()) } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async] async fn test_client_use_cert() -> Result<()> { let mut builder = TlsClientBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl)); diff --git a/rats-rs/src/transport/tls/server.rs b/rats-rs/src/transport/tls/server.rs index 65d6325..384cb5d 100644 --- a/rats-rs/src/transport/tls/server.rs +++ b/rats-rs/src/transport/tls/server.rs @@ -269,9 +269,9 @@ mod tests { } } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async] async fn test_server_shutdown() -> Result<()> { let mut builder = TlsServerBuilder::new(); builder.stream = Some(Box::new(GetFdDumpImpl {})); @@ -307,9 +307,9 @@ mod tests { now } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async] async fn test_server_use_key() -> Result<()> { ossl_init()?; let ctx = unsafe { SSL_CTX_new(TLS_server_method()) }; @@ -332,9 +332,9 @@ mod tests { Ok(()) } + #[maybe_async] #[cfg_attr(feature = "is-sync", test)] #[cfg_attr(not(feature = "is-sync"), tokio::test)] - #[maybe_async] async fn test_server_use_cert() -> Result<()> { ossl_init()?; let ctx = unsafe { SSL_CTX_new(TLS_server_method()) }; From c6bbaf100eb4d4345ce8a4af577908c265c53928 Mon Sep 17 00:00:00 2001 From: csyJoy Date: Sat, 12 Oct 2024 16:26:36 +0800 Subject: [PATCH 5/5] fix(misc): code quality and security improvements, comments added, documentation updated. 1. add missing error handling code. 2. move `Vec` instead of raw ptr while spawning a async block task 3. add comments explaining the reason for implementing `Send` to tls `Client` `Server` 4. document update about cargo building package from `--bin` to `-p` --- docs/how-to-build.md | 4 ++-- examples/spdm/README.md | 2 +- justfile | 4 ++-- rats-rs/Cargo.toml | 4 ++-- rats-rs/src/tee/sgx_dcap/attester.rs | 20 +++++++++++--------- rats-rs/src/transport/tls/client.rs | 3 +++ rats-rs/src/transport/tls/server.rs | 3 +++ 7 files changed, 24 insertions(+), 16 deletions(-) diff --git a/docs/how-to-build.md b/docs/how-to-build.md index 216bd79..476dd47 100644 --- a/docs/how-to-build.md +++ b/docs/how-to-build.md @@ -159,7 +159,7 @@ docker build --tag rats-rs:master . 4. (可选)构建样例程序 ```sh - cargo build --bin spdm + cargo build -p spdm ``` - 对于如何运行样例程序,请参考examples目录下的[例子](/examples/spdm)。 \ No newline at end of file + 对于如何运行样例程序,请参考examples目录下的[例子](/examples/spdm)。 diff --git a/examples/spdm/README.md b/examples/spdm/README.md index 2751472..c9a1029 100644 --- a/examples/spdm/README.md +++ b/examples/spdm/README.md @@ -11,7 +11,7 @@ 接下来,使用如下命令构建本样例程序 ```sh -cargo build --bin spdm +cargo build -p spdm ``` 可以使用`target/debug/spdm --help`命令查看该样例程序的命令行参数 diff --git a/justfile b/justfile index 2ae7686..85758bb 100644 --- a/justfile +++ b/justfile @@ -8,11 +8,11 @@ prepare-repo: cd deps/spdm-rs && sh_script/pre-build.sh run-in-occlum *args: - cargo build --bin spdm + cargo build -p spdm scripts/run_exe_in_occlum.sh target/debug/spdm {{args}} run-in-host *args: - cargo build --bin spdm + cargo build -p spdm target/debug/spdm {{args}} run-test-in-occlum *args: diff --git a/rats-rs/Cargo.toml b/rats-rs/Cargo.toml index c022911..8d2bb74 100644 --- a/rats-rs/Cargo.toml +++ b/rats-rs/Cargo.toml @@ -49,8 +49,8 @@ crypto-rustcrypto = ["dep:x509-cert", "dep:sha2", "dep:p256", "dep:rsa", "dep:pk base = ["crypto-rustcrypto", "attester-sgx-dcap-occlum", "verifier-sgx-dcap", "attester-tdx", "verifier-tdx"] default = ["base", "is-sync"] is-sync = ["maybe-async/is_sync", "spdmlib/is_sync"] -transport-spdm = ["dep:spdmlib", "dep:codec", "dep:ring", "is-sync"] -transport-tls = ["dep:openssl-sys", "dep:libc", "is-sync"] +transport-spdm = ["dep:spdmlib", "dep:codec", "dep:ring"] +transport-tls = ["dep:openssl-sys", "dep:libc"] transport-rustls = ["dep:tokio-rustls", "async-tokio"] attester-sgx-dcap = ["dep:intel-dcap"] attester-sgx-dcap-occlum = ["attester-sgx-dcap", "dep:occlum_dcap"] diff --git a/rats-rs/src/tee/sgx_dcap/attester.rs b/rats-rs/src/tee/sgx_dcap/attester.rs index 8a21556..62c17fa 100644 --- a/rats-rs/src/tee/sgx_dcap/attester.rs +++ b/rats-rs/src/tee/sgx_dcap/attester.rs @@ -38,28 +38,30 @@ impl GenericAttester for SgxDcapAttester { let mut sgx_report_data = sgx_report_data_t::default(); sgx_report_data.d[..report_data.len()].clone_from_slice(report_data); - let ptr = occlum_quote.as_mut_ptr() as usize; - #[cfg(feature = "async-tokio")] { - task::spawn_blocking(move || { + let handle = task::spawn_blocking(move || { handler .generate_quote( - ptr as *mut u8, + occlum_quote.as_mut_ptr(), &sgx_report_data as *const sgx_report_data_t, ) .kind(ErrorKind::SgxDcapAttesterGenerateQuoteFailed) - .context("failed at generate_quote()"); - }) - .await?; + .context("failed at generate_quote()") + .map(|_| occlum_quote) + }); + occlum_quote = handle.await.context("the quote generation task panics")??; } #[cfg(not(feature = "async-tokio"))] { handler - .generate_quote(ptr as *mut u8, &sgx_report_data as *const sgx_report_data_t) + .generate_quote( + occlum_quote.as_mut_ptr(), + &sgx_report_data as *const sgx_report_data_t, + ) .kind(ErrorKind::SgxDcapAttesterGenerateQuoteFailed) - .context("failed at generate_quote()"); + .context("failed at generate_quote()")?; } SgxDcapEvidence::new_from_checked(occlum_quote) diff --git a/rats-rs/src/transport/tls/client.rs b/rats-rs/src/transport/tls/client.rs index 388e307..8aa692e 100644 --- a/rats-rs/src/transport/tls/client.rs +++ b/rats-rs/src/transport/tls/client.rs @@ -37,6 +37,9 @@ pub struct Client { attest_self: bool, } +// `Client` is not 'Send' because it contains raw pointer which doesn't impl `Send` +// async methods capturing `&mut Client` need `Send` trait for `Client`, so we impl here. +#[cfg(feature = "async-tokio")] unsafe impl Send for Client {} pub struct TlsClientBuilder { diff --git a/rats-rs/src/transport/tls/server.rs b/rats-rs/src/transport/tls/server.rs index 384cb5d..d43cf9d 100644 --- a/rats-rs/src/transport/tls/server.rs +++ b/rats-rs/src/transport/tls/server.rs @@ -32,6 +32,9 @@ pub struct Server { stream: Box, } +// `Server` is not 'Send' because it contains raw pointer which doesn't impl `Send` +// async methods capturing `&mut Server` need `Send` trait for `Server`, so we impl here. +#[cfg(feature = "async-tokio")] unsafe impl Send for Server {} // TODO: use typestate design pattern?