diff --git a/mountpoint-s3-client/src/failure_client.rs b/mountpoint-s3-client/src/failure_client.rs
index 90d752e8a..264ed119e 100644
--- a/mountpoint-s3-client/src/failure_client.rs
+++ b/mountpoint-s3-client/src/failure_client.rs
@@ -227,6 +227,12 @@ impl GetObjectRequest
         self.request.get_object_checksum().await
     }
 
+    async fn get_object_sse(
+        &self,
+    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
+        Ok((None, None))
+    }
+
     fn increment_read_window(self: Pin<&mut Self>, len: usize) {
         let this = self.project();
         this.request.increment_read_window(len);
diff --git a/mountpoint-s3-client/src/mock_client.rs b/mountpoint-s3-client/src/mock_client.rs
index c2efad642..ae2964a16 100644
--- a/mountpoint-s3-client/src/mock_client.rs
+++ b/mountpoint-s3-client/src/mock_client.rs
@@ -542,6 +542,12 @@ impl GetObjectRequest for MockGetObjectRequest {
         Ok(self.object.checksum.clone())
     }
 
+    async fn get_object_sse(
+        &self,
+    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
+        Ok((None, None))
+    }
+
     fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
         self.read_window_end_offset += len as u64;
     }
diff --git a/mountpoint-s3-client/src/mock_client/throughput_client.rs b/mountpoint-s3-client/src/mock_client/throughput_client.rs
index a64d1b31e..83a57dd97 100644
--- a/mountpoint-s3-client/src/mock_client/throughput_client.rs
+++ b/mountpoint-s3-client/src/mock_client/throughput_client.rs
@@ -78,6 +78,12 @@ impl GetObjectRequest for ThroughputGetObjectRequest {
         Ok(self.request.object.checksum.clone())
     }
 
+    async fn get_object_sse(
+        &self,
+    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
+        Ok((None, None))
+    }
+
     fn increment_read_window(self: Pin<&mut Self>, len: usize) {
         let this = self.project();
         this.request.increment_read_window(len);
diff --git a/mountpoint-s3-client/src/object_client.rs b/mountpoint-s3-client/src/object_client.rs
index 5240ae952..25e70918f 100644
--- a/mountpoint-s3-client/src/object_client.rs
+++ b/mountpoint-s3-client/src/object_client.rs
@@ -572,6 +572,12 @@ pub trait GetObjectRequest:
     /// Get the object's checksum, if uploaded with one
    async fn get_object_checksum(&self) -> ObjectClientResult<Checksum, GetObjectError, Self::ClientError>;
 
+    /// Get the object's SSE type and KMS Key ARN, as defined in:
+    /// https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#key-id
+    async fn get_object_sse(
+        &self,
+    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError>;
+
     /// Increment the flow-control window, so that response data continues downloading.
     ///
     /// If the client was created with `enable_read_backpressure` set true,
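The new `get_object_sse` trait method above returns the object's SSE type and KMS key ARN as an optional pair. As a quick illustration (not part of the patch; the helper name, bucket, and key below are placeholders), a caller generic over `ObjectClient` might consume it like this:

```rust
use mountpoint_s3_client::types::{GetObjectParams, GetObjectRequest};
use mountpoint_s3_client::ObjectClient;

/// Hypothetical helper: fetch an object and print its SSE settings.
async fn print_object_sse<Client: ObjectClient>(client: &Client, bucket: &str, key: &str) {
    let request = client
        .get_object(bucket, key, &GetObjectParams::new())
        .await
        .expect("get_object should succeed");
    // New in this change: (SSE type, KMS key ARN), both optional.
    let (sse_type, kms_key_arn) = request
        .get_object_sse()
        .await
        .expect("should return SSE settings");
    println!("sse_type={sse_type:?}, kms_key_arn={kms_key_arn:?}");
}
```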
diff --git a/mountpoint-s3-client/src/s3_crt_client.rs b/mountpoint-s3-client/src/s3_crt_client.rs
index ed2b9f236..926dad2ba 100644
--- a/mountpoint-s3-client/src/s3_crt_client.rs
+++ b/mountpoint-s3-client/src/s3_crt_client.rs
@@ -1134,6 +1134,14 @@ fn parse_checksum(headers: &Headers) -> Result<Checksum, HeadersError> {
     })
 }
 
+/// Extract the SSE-type and SSE-key information from headers
+fn parse_object_sse(headers: &Headers) -> Result<(Option<String>, Option<String>), HeadersError> {
+    let sse_type = headers.get_as_optional_string("x-amz-server-side-encryption")?;
+    let sse_kms_key_id = headers.get_as_optional_string("x-amz-server-side-encryption-aws-kms-key-id")?;
+
+    Ok((sse_type, sse_kms_key_id))
+}
+
 /// Try to parse a modeled error out of a failing meta request
 fn try_parse_generic_error(request_result: &MetaRequestResult) -> Option<S3RequestError> {
     /// Look for a redirect header pointing to a different region for the bucket
diff --git a/mountpoint-s3-client/src/s3_crt_client/get_object.rs b/mountpoint-s3-client/src/s3_crt_client/get_object.rs
index f2be3c0df..6604ff66a 100644
--- a/mountpoint-s3-client/src/s3_crt_client/get_object.rs
+++ b/mountpoint-s3-client/src/s3_crt_client/get_object.rs
@@ -19,7 +19,8 @@ use crate::object_client::{
     Checksum, GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, ObjectClientResult, ObjectMetadata,
 };
 use crate::s3_crt_client::{
-    parse_checksum, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation, S3RequestError,
+    parse_checksum, parse_object_sse, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation,
+    S3RequestError,
 };
 use crate::types::ChecksumMode;
 
@@ -204,6 +205,14 @@ impl GetObjectRequest for S3GetObjectRequest {
         parse_checksum(&headers).map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e))))
     }
 
+    async fn get_object_sse(
+        &self,
+    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
+        let headers = self.get_object_headers().await?;
+        parse_object_sse(&headers)
+            .map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e))))
+    }
+
     fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
         self.read_window_end_offset += len as u64;
         self.request.meta_request.increment_read_window(len as u64);
diff --git a/mountpoint-s3-client/tests/common/mod.rs b/mountpoint-s3-client/tests/common/mod.rs
index 6dcb2965a..e5a915c6b 100644
--- a/mountpoint-s3-client/tests/common/mod.rs
+++ b/mountpoint-s3-client/tests/common/mod.rs
@@ -67,6 +67,20 @@ pub fn get_test_bucket() -> String {
     }
 }
 
+/// An S3 Express bucket with SSE-S3 set as the default encryption
+#[cfg(feature = "s3express_tests")]
+pub fn get_express_bucket() -> String {
+    std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME")
+        .expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME to run integration tests")
+}
+
+/// An S3 Express bucket with SSE-KMS set as the default encryption, with a key matching `KMS_TEST_KEY_ID`
+#[cfg(feature = "s3express_tests")]
+pub fn get_express_sse_kms_bucket() -> String {
+    std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME_SSE_KMS")
+        .expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME_SSE_KMS to run integration tests")
+}
+
 pub fn get_test_kms_key_id() -> String {
     std::env::var("KMS_TEST_KEY_ID").expect("Set KMS_TEST_KEY_ID to run integration tests")
 }
diff --git a/mountpoint-s3-client/tests/get_object.rs b/mountpoint-s3-client/tests/get_object.rs
index 9abfae092..485492da1 100644
--- a/mountpoint-s3-client/tests/get_object.rs
+++ b/mountpoint-s3-client/tests/get_object.rs
@@ -551,3 +551,59 @@ async fn test_get_object_checksum_checksums_disabled() {
         .await
         .expect_err("should not return a checksum object as not requested");
 }
+
+#[test_case("aws:kms", Some(get_test_kms_key_id()))]
+#[test_case("aws:kms:dsse", Some(get_test_kms_key_id()))]
+#[test_case("AES256", None)]
+#[tokio::test]
+#[cfg(not(feature = "s3express_tests"))]
+async fn test_get_object_sse(sse_type: &str, kms_key_id: Option<String>) {
+    test_get_object_sse_base(sse_type, kms_key_id, get_test_bucket()).await;
+}
+
+// We have a separate set of tests for S3 Express One Zone because:
+// 1. via the SDK / CRT we can only put an object with settings that match the bucket's defaults:
+//    https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-express-specifying-kms-encryption.html
+// 2. Express doesn't currently support `aws:kms:dsse`.
+#[test_case("AES256", None, get_express_bucket())]
+#[test_case("aws:kms", Some(get_test_kms_key_id()), get_express_sse_kms_bucket())]
+#[tokio::test]
+#[cfg(feature = "s3express_tests")]
+async fn test_get_object_sse(sse_type: &str, kms_key_id: Option<String>, bucket: String) {
+    test_get_object_sse_base(sse_type, kms_key_id, bucket).await;
+}
+
+async fn test_get_object_sse_base(sse_type: &str, kms_key_id: Option<String>, bucket: String) {
+    let sdk_client = get_test_sdk_client().await;
+    let prefix = get_unique_test_prefix("test_get_object_sse");
+
+    let key = format!("{prefix}/test");
+    let body = vec![0x42; 42];
+    let mut request = sdk_client
+        .put_object()
+        .bucket(&bucket)
+        .key(&key)
+        .body(ByteStream::from(body.clone()))
+        .server_side_encryption(sse_type.into());
+    if let Some(kms_key_id) = kms_key_id.as_ref() {
+        request = request.ssekms_key_id(kms_key_id)
+    }
+    request.send().await.unwrap();
+
+    let client: S3CrtClient = get_test_client();
+
+    let result = client
+        .get_object(
+            &bucket,
+            &key,
+            &GetObjectParams::new().checksum_mode(Some(ChecksumMode::Enabled)),
+        )
+        .await
+        .expect("get_object should succeed");
+
+    let (received_sse, received_key_id) = result.get_object_sse().await.expect("should return sse settings");
+    assert_eq!(received_sse.expect("sse type must be some"), sse_type);
+    if let Some(kms_key_id) = kms_key_id {
+        assert_eq!(received_key_id.expect("sse key must be some"), kms_key_id);
+    }
+}
diff --git a/mountpoint-s3-client/tests/put_object_single.rs b/mountpoint-s3-client/tests/put_object_single.rs
index 28d8dda69..e754e40d5 100644
--- a/mountpoint-s3-client/tests/put_object_single.rs
+++ b/mountpoint-s3-client/tests/put_object_single.rs
@@ -205,7 +205,6 @@ async fn test_put_object_storage_class(storage_class: &str) {
     assert_eq!(storage_class, attributes.storage_class.unwrap().as_str());
 }
 
-#[cfg(not(feature = "s3express_tests"))]
 async fn check_sse(
     bucket: &String,
     key: &String,
@@ -292,6 +291,52 @@ async fn test_put_object_sse(sse_type: Option<&str>, kms_key_id: Option<String>)
     check_sse(&bucket, &key, sse_type, &kms_key_id, put_object_result).await;
 }
 
+#[test_case(Some("aws:kms"), Some(get_test_kms_key_id()), get_express_sse_kms_bucket(), false)]
+#[test_case(Some("aws:kms"), None, get_express_sse_kms_bucket(), false)]
+#[test_case(Some("aws:kms"), Some(get_test_kms_key_id()), get_express_bucket(), true)]
+#[test_case(Some("aws:kms"), None, get_express_bucket(), true)]
+#[test_case(Some("AES256"), None, get_express_bucket(), false)]
+#[test_case(Some("AES256"), None, get_express_sse_kms_bucket(), true)]
+#[test_case(None, None, get_express_bucket(), false)]
+#[tokio::test]
+#[cfg(feature = "s3express_tests")]
+async fn test_put_object_sse(sse_type: Option<&str>, kms_key_id: Option<String>, bucket: String, should_fail: bool) {
+    let client_config = S3ClientConfig::new().endpoint_config(get_test_endpoint_config());
+    let client = S3CrtClient::new(client_config).expect("could not create test client");
+    let request_params = PutObjectSingleParams::new()
+        .server_side_encryption(sse_type.map(|value| value.to_owned()))
+        .ssekms_key_id(kms_key_id.to_owned());
+
+    // Make a request
+    let prefix = get_unique_test_prefix("test_put_object_sse");
+    let key = format!("{prefix}hello");
+
+    let mut rng = rand::thread_rng();
+    let mut contents = vec![0u8; 32];
+    rng.fill(&mut contents[..]);
+
+    let put_object_result = client
+        .put_object_single(&bucket, &key, &request_params, &contents)
+        .await;
+
+    if should_fail {
+        assert!(put_object_result.is_err(), "put request should fail");
+        return;
+    } else {
+        assert!(put_object_result.is_ok(), "put request should succeed");
+    }
+
+    // Check the SSE of the object via the SDK and the values returned in the PUT response
+    check_sse(
+        &bucket,
+        &key,
+        sse_type,
+        &kms_key_id,
+        put_object_result.expect("put request should succeed"),
+    )
+    .await;
+}
+
 #[tokio::test]
 async fn test_put_object_header() {
     let (bucket, prefix) = get_test_bucket_and_prefix("test_put_object_header");
diff --git a/mountpoint-s3/src/cli.rs b/mountpoint-s3/src/cli.rs
index f0ead18c0..1c3dd0efc 100644
--- a/mountpoint-s3/src/cli.rs
+++ b/mountpoint-s3/src/cli.rs
@@ -894,7 +894,13 @@ where
     match (args.disk_data_cache_config(), args.express_data_cache_config()) {
         (None, Some((config, bucket_name, cache_bucket_name))) => {
             tracing::trace!("using S3 Express One Zone bucket as a cache for object content");
-            let express_cache = ExpressDataCache::new(client.clone(), config, bucket_name, cache_bucket_name);
+            let express_cache = ExpressDataCache::new(
+                client.clone(),
+                config,
+                bucket_name,
+                cache_bucket_name,
+                filesystem_config.server_side_encryption.clone(),
+            );
 
             let prefetcher = caching_prefetch(express_cache, runtime, prefetcher_config);
             let fuse_session = create_filesystem(
@@ -933,7 +939,13 @@ where
         (Some((disk_data_cache_config, cache_dir_path)), Some((config, bucket_name, cache_bucket_name))) => {
             tracing::trace!("using both local disk and S3 Express One Zone bucket as a cache for object content");
             let (managed_cache_dir, disk_cache) = create_disk_cache(cache_dir_path, disk_data_cache_config)?;
-            let express_cache = ExpressDataCache::new(client.clone(), config, bucket_name, cache_bucket_name);
+            let express_cache = ExpressDataCache::new(
+                client.clone(),
+                config,
+                bucket_name,
+                cache_bucket_name,
+                filesystem_config.server_side_encryption.clone(),
+            );
 
             let cache = MultilevelDataCache::new(Arc::new(disk_cache), express_cache, runtime.clone());
             let prefetcher = caching_prefetch(cache, runtime, prefetcher_config);
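The two call sites above pass the filesystem's SSE settings straight through to the cache. For reference, a standalone sketch of constructing the cache with an explicit SSE-KMS configuration instead (the helper name, bucket names, and key ARN below are placeholders, not part of the patch):

```rust
use mountpoint_s3::data_cache::ExpressDataCache;
use mountpoint_s3::ServerSideEncryption;
use mountpoint_s3_client::S3CrtClient;

/// Hypothetical helper: build an express cache that encrypts cache blocks with a specific KMS key.
fn build_express_cache(client: S3CrtClient) -> ExpressDataCache<S3CrtClient> {
    let sse = ServerSideEncryption::new(
        Some("aws:kms".to_owned()),
        Some("arn:aws:kms:us-east-1:111122223333:key/EXAMPLE-KEY-ID".to_owned()),
    );
    ExpressDataCache::new(
        client,
        Default::default(), // ExpressDataCacheConfig
        "my-source-bucket",
        "my-express-cache-bucket--usw2-az1--x-s3",
        sse,
    )
}
```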
diff --git a/mountpoint-s3/src/data_cache/express_data_cache.rs b/mountpoint-s3/src/data_cache/express_data_cache.rs
index a9fac357a..cc37706cc 100644
--- a/mountpoint-s3/src/data_cache/express_data_cache.rs
+++ b/mountpoint-s3/src/data_cache/express_data_cache.rs
@@ -1,4 +1,5 @@
 use crate::object::ObjectId;
+use crate::ServerSideEncryption;
 
 use super::{BlockIndex, ChecksummedBytes, DataCache, DataCacheError, DataCacheResult};
 
@@ -6,10 +7,10 @@
 use async_trait::async_trait;
 use bytes::BytesMut;
 use futures::{pin_mut, StreamExt};
 use mountpoint_s3_client::error::{GetObjectError, ObjectClientError};
-use mountpoint_s3_client::types::{GetObjectParams, GetObjectRequest, PutObjectParams};
-use mountpoint_s3_client::{ObjectClient, PutObjectRequest};
+use mountpoint_s3_client::types::{GetObjectParams, GetObjectRequest, PutObjectSingleParams};
+use mountpoint_s3_client::ObjectClient;
 use sha2::{Digest, Sha256};
-use tracing::Instrument;
+use tracing::{error, Instrument};
 
 const CACHE_VERSION: &str = "V1";
 
@@ -38,6 +39,7 @@ pub struct ExpressDataCache<Client: ObjectClient> {
     config: ExpressDataCacheConfig,
     /// Name of the S3 Express bucket to store the blocks.
     bucket_name: String,
+    sse: ServerSideEncryption,
 }
 
 impl<E> From<ObjectClientError<GetObjectError, E>> for DataCacheError
@@ -57,7 +59,13 @@ where
     /// Create a new instance.
    ///
     /// TODO: consider adding some validation of the bucket.
-    pub fn new(client: Client, config: ExpressDataCacheConfig, source_bucket_name: &str, bucket_name: &str) -> Self {
+    pub fn new(
+        client: Client,
+        config: ExpressDataCacheConfig,
+        source_bucket_name: &str,
+        bucket_name: &str,
+        sse: ServerSideEncryption,
+    ) -> Self {
         let prefix = hex::encode(
             Sha256::new()
                 .chain_update(CACHE_VERSION.as_bytes())
@@ -65,11 +73,13 @@ where
                 .chain_update(source_bucket_name.as_bytes())
                 .finalize(),
         );
+
         Self {
             client,
             prefix,
            config,
             bucket_name: bucket_name.to_owned(),
+            sse,
         }
     }
 }
@@ -126,6 +136,15 @@ where
                 Err(e) => return Err(e.into()),
             }
         }
+
+        let (sse_type, sse_kms_key_id) = result
+            .get_object_sse()
+            .await
+            .map_err(|err| DataCacheError::IoFailure(err.into()))?;
+        self.sse
+            .verify_response(sse_type.as_deref(), sse_kms_key_id.as_deref())
+            .map_err(|err| DataCacheError::IoFailure(err.into()))?;
+
         let buffer = buffer.freeze();
         DataCacheResult::Ok(Some(buffer.into()))
     }
@@ -148,16 +167,35 @@ where
 
         let object_key = block_key(&self.prefix, &cache_key, block_idx);
 
-        // TODO: ideally we should use a simple Put rather than MPU.
-        let params = PutObjectParams::new();
-        let mut req = self
+        let mut params = PutObjectSingleParams::new();
+        let (sse_type, key_id) = self
+            .sse
+            .clone()
+            .into_inner()
+            .map_err(|err| DataCacheError::IoFailure(err.into()))?;
+        params = params.server_side_encryption(sse_type);
+        params = params.ssekms_key_id(key_id);
+
+        let (data, _crc) = bytes.into_inner().map_err(|_| DataCacheError::InvalidBlockContent)?;
+        let req = self
             .client
-            .put_object(&self.bucket_name, &object_key, &params)
+            .put_object_single(&self.bucket_name, &object_key, &params, data)
             .in_current_span()
             .await?;
-        let (data, _crc) = bytes.into_inner().map_err(|_| DataCacheError::InvalidBlockContent)?;
-        req.write(&data).await?;
-        req.complete().await?;
+
+        // Verify that the headers of the PUT response match the expected SSE settings
+        if let Err(err) = self
+            .sse
+            .verify_response(req.sse_type.as_deref(), req.sse_kms_key_id.as_deref())
+        {
+            error!(key=?cache_key, error=?err, "A cache block was stored with wrong encryption settings");
+            // Reaching this point is very unlikely and means that the SSE settings were corrupted in transit or on the S3 side;
+            // this may be a sign of a bug in the CRT code or S3. Thus, we terminate Mountpoint to send the most noticeable signal
+            // to the customer about the issue. We prefer exiting instead of returning an error because:
+            // 1. the error would only be reported in logs, since cache population is an asynchronous process;
+            // 2. the reported error is severe, as the object was already uploaded to S3.
+            std::process::exit(1);
+        }
 
         DataCacheResult::Ok(())
     }
@@ -206,7 +244,7 @@ mod tests {
             block_size,
             ..Default::default()
         };
-        let cache = ExpressDataCache::new(client, config, "unique source description", bucket);
+        let cache = ExpressDataCache::new(client, config, "unique source description", bucket, Default::default());
 
         let data_1 = ChecksummedBytes::new("Foo".into());
         let data_2 = ChecksummedBytes::new("Bar".into());
@@ -299,7 +337,13 @@ mod tests {
             ..Default::default()
         };
         let client = Arc::new(MockClient::new(config));
-        let cache = ExpressDataCache::new(client.clone(), Default::default(), "unique source description", bucket);
+        let cache = ExpressDataCache::new(
+            client.clone(),
+            Default::default(),
+            "unique source description",
+            bucket,
+            Default::default(),
+        );
         let data_1 = vec![0u8; 1024 * 1024 + 1];
         let data_1 = ChecksummedBytes::new(data_1.into());
         let cache_key_1 = ObjectId::new("a".into(), ETag::for_tests());
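Both `get_block` and `put_block` above now compare the SSE type and KMS key ID reported in the response against the values the cache was configured with, and the comment in `put_block` explains why a mismatch there is treated as fatal. A simplified, hypothetical version of that comparison is sketched below; it is not the crate's `verify_response` implementation and assumes a strict equality check, which may be stricter than the real type for values that were left unconfigured:

```rust
/// Hypothetical stand-in for the check done via `ServerSideEncryption::verify_response`.
fn sse_response_matches(
    expected: (Option<&str>, Option<&str>),
    actual: (Option<&str>, Option<&str>),
) -> Result<(), String> {
    if expected == actual {
        Ok(())
    } else {
        Err(format!("SSE mismatch: expected {expected:?}, got {actual:?}"))
    }
}
```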
diff --git a/mountpoint-s3/src/data_cache/multilevel_cache.rs b/mountpoint-s3/src/data_cache/multilevel_cache.rs
index 403b23313..c0247be4c 100644
--- a/mountpoint-s3/src/data_cache/multilevel_cache.rs
+++ b/mountpoint-s3/src/data_cache/multilevel_cache.rs
@@ -151,7 +151,13 @@ mod tests {
             ..Default::default()
         };
         let client = MockClient::new(config);
-        let cache = ExpressDataCache::new(client.clone(), Default::default(), "unique source description", bucket);
+        let cache = ExpressDataCache::new(
+            client.clone(),
+            Default::default(),
+            "unique source description",
+            bucket,
+            Default::default(),
+        );
 
         (client, cache)
     }
diff --git a/mountpoint-s3/tests/common/cache.rs b/mountpoint-s3/tests/common/cache.rs
new file mode 100644
index 000000000..f4d52b657
--- /dev/null
+++ b/mountpoint-s3/tests/common/cache.rs
@@ -0,0 +1,126 @@
+use std::{
+    sync::{
+        atomic::{AtomicU64, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
+
+use async_trait::async_trait;
+use mountpoint_s3::{
+    data_cache::{BlockIndex, ChecksummedBytes, DataCache, DataCacheResult},
+    object::ObjectId,
+};
+
+/// A wrapper around any type implementing [DataCache], which counts operations
+pub struct CacheTestWrapper<Cache> {
+    cache: Arc<Cache>,
+    get_block_ok_count: Arc<AtomicU64>,
+    get_block_hit_count: Arc<AtomicU64>,
+    get_block_failed_count: Arc<AtomicU64>,
+    put_block_ok_count: Arc<AtomicU64>,
+    put_block_failed_count: Arc<AtomicU64>,
+}
+
+impl<Cache> Clone for CacheTestWrapper<Cache> {
+    fn clone(&self) -> Self {
+        Self {
+            cache: self.cache.clone(),
+            get_block_ok_count: self.get_block_ok_count.clone(),
+            get_block_hit_count: self.get_block_hit_count.clone(),
+            get_block_failed_count: self.get_block_failed_count.clone(),
+            put_block_ok_count: self.put_block_ok_count.clone(),
+            put_block_failed_count: self.put_block_failed_count.clone(),
+        }
+    }
+}
+
+impl<Cache: DataCache> CacheTestWrapper<Cache> {
+    pub fn new(cache: Arc<Cache>) -> Self {
+        CacheTestWrapper {
+            cache,
+            get_block_ok_count: Arc::new(AtomicU64::new(0)),
+            get_block_hit_count: Arc::new(AtomicU64::new(0)),
+            get_block_failed_count: Arc::new(AtomicU64::new(0)),
+            put_block_ok_count: Arc::new(AtomicU64::new(0)),
+            put_block_failed_count: Arc::new(AtomicU64::new(0)),
+        }
+    }
+
+    pub fn wait_for_put(&self, max_wait_duration: Duration) {
+        let st = std::time::Instant::now();
+        loop {
+            if st.elapsed() > max_wait_duration {
+                panic!("timeout on waiting for a write to the cache to happen")
+            }
+            if self.put_block_failed_count.load(Ordering::SeqCst) > 0
+                || self.put_block_ok_count.load(Ordering::SeqCst) > 0
+            {
+                break;
+            }
+            std::thread::sleep(Duration::from_millis(100));
+        }
+    }
+
+    pub fn get_block_hit_count(&self) -> u64 {
+        self.get_block_hit_count.load(Ordering::SeqCst)
+    }
+
+    pub fn put_block_ok_count(&self) -> u64 {
+        self.put_block_ok_count.load(Ordering::SeqCst)
+    }
+
+    pub fn put_block_failed_count(&self) -> u64 {
+        self.put_block_failed_count.load(Ordering::SeqCst)
+    }
+}
+
+#[async_trait]
+impl<Cache: DataCache + Send + Sync> DataCache for CacheTestWrapper<Cache> {
+    async fn get_block(
+        &self,
+        cache_key: &ObjectId,
+        block_idx: BlockIndex,
+        block_offset: u64,
+        object_size: usize,
+    ) -> DataCacheResult<Option<ChecksummedBytes>> {
+        let result = self
+            .cache
+            .get_block(cache_key, block_idx, block_offset, object_size)
+            .await
+            .inspect(|_| {
+                self.get_block_ok_count.fetch_add(1, Ordering::SeqCst);
+            })
+            .inspect_err(|_| {
+                self.get_block_failed_count.fetch_add(1, Ordering::SeqCst);
+            })?
+            .inspect(|_| {
+                self.get_block_hit_count.fetch_add(1, Ordering::SeqCst);
+            });
+
+        Ok(result)
+    }
+
+    async fn put_block(
+        &self,
+        cache_key: ObjectId,
+        block_idx: BlockIndex,
+        block_offset: u64,
+        bytes: ChecksummedBytes,
+        object_size: usize,
+    ) -> DataCacheResult<()> {
+        self.cache
+            .put_block(cache_key, block_idx, block_offset, bytes, object_size)
+            .await
+            .inspect(|_| {
+                self.put_block_ok_count.fetch_add(1, Ordering::SeqCst);
+            })
+            .inspect_err(|_| {
+                self.put_block_failed_count.fetch_add(1, Ordering::SeqCst);
+            })
+    }
+
+    fn block_size(&self) -> u64 {
+        self.cache.block_size()
+    }
+}
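A compact sketch of how the wrapper above is meant to be used (the fuse test added later in this patch wraps an `ExpressDataCache` the same way; the helper below is illustrative only):

```rust
use std::sync::Arc;
use std::time::Duration;

use mountpoint_s3::data_cache::DataCache;

use crate::common::cache::CacheTestWrapper;

/// Hypothetical helper: wrap a cache, hand clones to the code under test, then assert on the counters.
fn wrap_and_check<Cache: DataCache + Send + Sync + 'static>(cache: Cache) {
    let wrapper = CacheTestWrapper::new(Arc::new(cache));
    // ... pass `wrapper.clone()` to caching_prefetch / the filesystem under test ...

    // Cache population is asynchronous: block until at least one put_block attempt has finished.
    wrapper.wait_for_put(Duration::from_secs(10));
    assert_eq!(wrapper.put_block_failed_count(), 0, "no cache writes should fail");
}
```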
diff --git a/mountpoint-s3/tests/common/fuse.rs b/mountpoint-s3/tests/common/fuse.rs
index 48dac3052..d65519a6b 100644
--- a/mountpoint-s3/tests/common/fuse.rs
+++ b/mountpoint-s3/tests/common/fuse.rs
@@ -113,7 +113,7 @@ pub trait TestSessionCreator: FnOnce(&str, TestSessionConfig) -> TestSession {}
 // `FnOnce(...)` in place of `impl TestSessionCreator`.
 impl<T> TestSessionCreator for T where T: FnOnce(&str, TestSessionConfig) -> TestSession {}
 
-fn create_fuse_session<Client, Prefetcher>(
+pub fn create_fuse_session<Client, Prefetcher>(
     client: Client,
     prefetcher: Prefetcher,
     bucket: &str,
@@ -363,12 +363,7 @@ pub mod s3_session {
             let (bucket, prefix) = get_test_bucket_and_prefix(test_name);
             let region = get_test_region();
 
-            let client_config = S3ClientConfig::default()
-                .part_size(test_config.part_size)
-                .endpoint_config(get_test_endpoint_config())
-                .read_backpressure(true)
-                .initial_read_window(test_config.initial_read_window_size);
-            let client = S3CrtClient::new(client_config).unwrap();
+            let client = create_crt_client(test_config.part_size, test_config.initial_read_window_size);
             let runtime = client.event_loop_group();
             let prefetcher = caching_prefetch(cache, runtime, test_config.prefetcher_config);
             let session = create_fuse_session(
@@ -385,7 +380,16 @@ pub mod s3_session {
         }
     }
 
-    fn create_test_client(region: &str, bucket: &str, prefix: &str) -> impl TestClient {
+    pub fn create_crt_client(part_size: usize, initial_read_window_size: usize) -> S3CrtClient {
+        let client_config = S3ClientConfig::default()
+            .part_size(part_size)
+            .endpoint_config(get_test_endpoint_config())
+            .read_backpressure(true)
+            .initial_read_window(initial_read_window_size);
+        S3CrtClient::new(client_config).unwrap()
+    }
+
+    pub fn create_test_client(region: &str, bucket: &str, prefix: &str) -> impl TestClient {
         let sdk_client = tokio_block_on(async { get_test_sdk_client(region).await });
         SDKTestClient {
             prefix: prefix.to_owned(),
diff --git a/mountpoint-s3/tests/common/mod.rs b/mountpoint-s3/tests/common/mod.rs
index f60d3ec10..75067f19c 100644
--- a/mountpoint-s3/tests/common/mod.rs
+++ b/mountpoint-s3/tests/common/mod.rs
@@ -2,6 +2,8 @@
 //! Allow for unused items since this is included independently in each module.
 #![allow(dead_code)]
 
+pub mod cache;
+
 pub mod creds;
 
 #[cfg(feature = "fuse_tests")]
diff --git a/mountpoint-s3/tests/common/s3.rs b/mountpoint-s3/tests/common/s3.rs
index f04fad003..cb6fc3573 100644
--- a/mountpoint-s3/tests/common/s3.rs
+++ b/mountpoint-s3/tests/common/s3.rs
@@ -8,13 +8,17 @@ use rand_chacha::rand_core::OsRng;
 use crate::common::tokio_block_on;
 
 pub fn get_test_bucket_and_prefix(test_name: &str) -> (String, String) {
-    let bucket = if cfg!(feature = "s3express_tests") {
-        std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME")
-            .expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME to run integration tests")
-    } else {
-        std::env::var("S3_BUCKET_NAME").expect("Set S3_BUCKET_NAME to run integration tests")
-    };
+    #[cfg(not(feature = "s3express_tests"))]
+    let bucket = get_standard_bucket();
+    #[cfg(feature = "s3express_tests")]
+    let bucket = get_express_bucket();
+
+    let prefix = get_test_prefix(test_name);
+
+    (bucket, prefix)
+}
 
+pub fn get_test_prefix(test_name: &str) -> String {
     // Generate a random nonce to make sure this prefix is truly unique
     let nonce = OsRng.next_u64();
 
@@ -22,9 +26,17 @@ pub fn get_test_bucket_and_prefix(test_name: &str) -> (String, String) {
     let prefix = std::env::var("S3_BUCKET_TEST_PREFIX").unwrap_or(String::from("mountpoint-test/"));
     assert!(prefix.ends_with('/'), "S3_BUCKET_TEST_PREFIX should end in '/'");
 
-    let prefix = format!("{prefix}{test_name}/{nonce}/");
+    format!("{prefix}{test_name}/{nonce}/")
+}
 
-    (bucket, prefix)
+#[cfg(feature = "s3express_tests")]
+pub fn get_express_bucket() -> String {
+    std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME")
+        .expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME to run integration tests")
+}
+
+pub fn get_standard_bucket() -> String {
+    std::env::var("S3_BUCKET_NAME").expect("Set S3_BUCKET_NAME to run integration tests")
 }
 
 pub fn get_test_bucket_forbidden() -> String {
diff --git a/mountpoint-s3/tests/fuse_tests/cache_test.rs b/mountpoint-s3/tests/fuse_tests/cache_test.rs
new file mode 100644
index 000000000..c14b24537
--- /dev/null
+++ b/mountpoint-s3/tests/fuse_tests/cache_test.rs
@@ -0,0 +1,99 @@
+use std::{fs, sync::Arc, time::Duration};
+
+use mountpoint_s3::{data_cache::ExpressDataCache, prefetch::caching_prefetch, ServerSideEncryption};
+use rand::{Rng, RngCore, SeedableRng};
+use rand_chacha::ChaChaRng;
+use test_case::test_case;
+
+use crate::common::{
+    cache::CacheTestWrapper,
+    fuse::{create_fuse_session, s3_session::create_crt_client},
+    s3::{get_express_bucket, get_standard_bucket, get_test_prefix},
+};
+
+const CLIENT_PART_SIZE: usize = 8 * 1024 * 1024;
+
+/// We want data to be stored in the cache with the provided SSE settings.
+/// In some cases this is not possible, so we expect a failure.
+#[test_case("aws:sse", true; "Invalid SSE (does not match the default)")]
+#[test_case("AES256", false; "Valid SSE (matches the default)")]
+fn express_cache_enforced_sse_on_put(sse: &str, should_fail: bool) {
+    let bucket_name = get_standard_bucket();
+    let prefix = get_test_prefix("express_cache_bad_sse");
+    let express_bucket_name = get_express_bucket();
+    let client = create_crt_client(CLIENT_PART_SIZE, CLIENT_PART_SIZE);
+    let express_cache = ExpressDataCache::new(
+        client.clone(),
+        Default::default(),
+        &bucket_name,
+        &express_bucket_name,
+        ServerSideEncryption::new(Some(sse.to_owned()), None),
+    );
+    let express_cache = CacheTestWrapper::new(Arc::new(express_cache));
+
+    // Mount a bucket
+    let mount_point = tempfile::tempdir().unwrap();
+    let runtime = client.event_loop_group();
+    let prefetcher = caching_prefetch(express_cache.clone(), runtime, Default::default());
+    let _session = create_fuse_session(
+        client,
+        prefetcher,
+        &bucket_name,
+        &prefix,
+        mount_point.path(),
+        Default::default(),
+    );
+
+    // Write an object; no caching happens yet
+    let key = get_object_key(&prefix, "key", 100);
+    let path = mount_point.path().join(&key);
+    let written = random_binary_data(1024 * 1024);
+    fs::write(&path, &written).expect("write should succeed");
+
+    // The first read is served from the source bucket, not the cache; in the failure case nothing is cached because the SSE settings can not be enforced
+    let read = fs::read(&path).expect("read should succeed");
+    assert_eq!(read, written);
+
+    // Cache writes are async, wait for that to happen
+    express_cache.wait_for_put(Duration::from_secs(10));
+
+    // Depending on the test case, check that either all writes failed or all succeeded
+    if should_fail {
+        assert_eq!(express_cache.put_block_ok_count(), 0)
+    } else {
+        assert_eq!(express_cache.put_block_failed_count(), 0);
+    }
+
+    // TODO: check with the SDK client that the data is stored with the right settings or not stored at all
+}
+
+// #[test]
+// fn express_cache_enforced_sse_on_get();
+
+// #[test]
+// fn express_cache_expected_bucket_owner_on_get();
+
+// #[test]
+// fn express_cache_expected_bucket_owner_on_put();
+
+// #[test]
+// fn express_cache_wrong_etag();
+
+fn random_binary_data(size_in_bytes: usize) -> Vec<u8> {
+    let seed = rand::thread_rng().gen();
+    let mut rng = ChaChaRng::seed_from_u64(seed);
+    let mut data = vec![0; size_in_bytes];
+    rng.fill_bytes(&mut data);
+    data
+}
+
+// Creates a random key which has a size of at least `min_size_in_bytes`
+fn get_object_key(key_prefix: &str, key_suffix: &str, min_size_in_bytes: usize) -> String {
+    let random_suffix: u64 = rand::thread_rng().gen();
+    let last_key_part = format!("{key_suffix}{random_suffix}"); // part of the key after all the "/"
+    let full_key = format!("{key_prefix}{last_key_part}");
+    let full_key_size = full_key.as_bytes().len();
+    let padding_size = min_size_in_bytes.saturating_sub(full_key_size);
+    let padding = "0".repeat(padding_size);
+    format!("{last_key_part}{padding}")
+}
diff --git a/mountpoint-s3/tests/fuse_tests/mod.rs b/mountpoint-s3/tests/fuse_tests/mod.rs
index 2b24c0d1e..f4347439e 100644
--- a/mountpoint-s3/tests/fuse_tests/mod.rs
+++ b/mountpoint-s3/tests/fuse_tests/mod.rs
@@ -1,3 +1,5 @@
+#[cfg(all(feature = "s3_tests", feature = "s3express_tests"))]
+mod cache_test;
 mod consistency_test;
 mod fork_test;
 mod lookup_test;