diff --git a/mountpoint-s3-client/CHANGELOG.md b/mountpoint-s3-client/CHANGELOG.md
index 5a46bb713..fa0d1796a 100644
--- a/mountpoint-s3-client/CHANGELOG.md
+++ b/mountpoint-s3-client/CHANGELOG.md
@@ -1,3 +1,13 @@
+## Unreleased
+
+### Breaking changes
+
+* When using GetObject with backpressure enabled, an error is now returned when there is not enough read window, instead of blocking. ([#971](https://github.com/awslabs/mountpoint-s3/pull/971))
+
+### Other changes
+
+* Allow querying the initial read window size and the read window end offset for backpressure GetObject. ([#971](https://github.com/awslabs/mountpoint-s3/pull/971))
+
 ## v0.9.0 (June 26, 2024)
 
 * Adds support for `AWS_ENDPOINT_URL` environment variable. ([#895](https://github.com/awslabs/mountpoint-s3/pull/895))
diff --git a/mountpoint-s3-client/src/failure_client.rs b/mountpoint-s3-client/src/failure_client.rs
index e0f2b43e7..b7d78a25d 100644
--- a/mountpoint-s3-client/src/failure_client.rs
+++ b/mountpoint-s3-client/src/failure_client.rs
@@ -81,6 +81,10 @@ where
         self.client.write_part_size()
     }
 
+    fn initial_read_window_size(&self) -> Option<usize> {
+        self.client.initial_read_window_size()
+    }
+
     async fn delete_object(
         &self,
         bucket: &str,
@@ -188,6 +192,11 @@ impl GetObjectRequest for FailureGetRequest
         let this = self.project();
         this.request.increment_read_window(len);
     }
+
+    fn read_window_end_offset(self: Pin<&Self>) -> u64 {
+        let this = self.project_ref();
+        this.request.read_window_end_offset()
+    }
 }
 
 impl Stream for FailureGetRequest {
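The changelog entry above tightens the contract for backpressure reads: a GetObject stream can now fail with an empty-read-window error rather than parking the caller. A minimal sketch of the consumption pattern this implies, built only on the trait methods this diff adds (`read_window_end_offset`, `increment_read_window`); the helper name and the error handling are illustrative, not part of the change:

```rust
use futures::{pin_mut, StreamExt};
use mountpoint_s3_client::types::GetObjectRequest;

// Hypothetical driver loop: keep the read window ahead of the bytes already
// received so the stream is never polled with an exhausted window.
async fn read_all<R: GetObjectRequest>(request: R, window: usize) -> Vec<u8> {
    let mut accum = vec![];
    pin_mut!(request);
    while let Some(part) = request.next().await {
        let Ok((offset, body)) = part else {
            break; // sketch only; real callers should surface the error
        };
        accum.extend_from_slice(&body);
        // Grow the window before the next poll would run past its end offset.
        while offset + body.len() as u64 >= request.as_ref().read_window_end_offset() {
            request.as_mut().increment_read_window(window);
        }
    }
    accum
}
```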
diff --git a/mountpoint-s3-client/src/mock_client.rs b/mountpoint-s3-client/src/mock_client.rs
index 902b27c96..b663a9463 100644
--- a/mountpoint-s3-client/src/mock_client.rs
+++ b/mountpoint-s3-client/src/mock_client.rs
@@ -64,7 +64,7 @@ pub struct MockClientConfig {
     /// A seed to randomize the order of ListObjectsV2 results, or None to use ordered list
    pub unordered_list_seed: Option<u64>,
     /// A flag to enable backpressure read
-    pub enable_back_pressure: bool,
+    pub enable_backpressure: bool,
     /// Initial backpressure read window size, ignored if backpressure is disabled
     pub initial_read_window_size: usize,
 }
@@ -475,8 +475,8 @@ pub struct MockGetObjectRequest {
     next_offset: u64,
     length: usize,
     part_size: usize,
-    enable_back_pressure: bool,
-    current_window_size: usize,
+    enable_backpressure: bool,
+    read_window_end_offset: u64,
 }
 
 impl MockGetObjectRequest {
@@ -498,7 +498,11 @@ impl GetObjectRequest for MockGetObjectRequest {
     type ClientError = MockClientError;
 
     fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
-        self.current_window_size += len;
+        self.read_window_end_offset += len as u64;
+    }
+
+    fn read_window_end_offset(self: Pin<&Self>) -> u64 {
+        self.read_window_end_offset
     }
 }
@@ -510,15 +514,13 @@ impl Stream for MockGetObjectRequest {
             return Poll::Ready(None);
         }
 
-        let mut next_read_size = self.part_size.min(self.length);
+        let next_read_size = self.part_size.min(self.length);
 
         // Simulate backpressure mechanism
-        if self.enable_back_pressure {
-            if self.current_window_size == 0 {
-                return Poll::Pending;
-            }
-            next_read_size = self.current_window_size.min(next_read_size);
-            self.current_window_size -= next_read_size;
+        if self.enable_backpressure && self.next_offset >= self.read_window_end_offset {
+            return Poll::Ready(Some(Err(ObjectClientError::ClientError(MockClientError(
+                "empty read window".into(),
+            )))));
         }
 
         let next_part = self.object.read(self.next_offset, next_read_size);
@@ -562,6 +564,14 @@ impl ObjectClient for MockClient {
         Some(self.config.part_size)
     }
 
+    fn initial_read_window_size(&self) -> Option<usize> {
+        if self.config.enable_backpressure {
+            Some(self.config.initial_read_window_size)
+        } else {
+            None
+        }
+    }
+
     async fn delete_object(
         &self,
         bucket: &str,
@@ -616,8 +626,8 @@ impl ObjectClient for MockClient {
                 next_offset,
                 length,
                 part_size: self.config.part_size,
-                enable_back_pressure: self.config.enable_back_pressure,
-                current_window_size: self.config.initial_read_window_size,
+                enable_backpressure: self.config.enable_backpressure,
+                read_window_end_offset: next_offset + self.config.initial_read_window_size as u64,
             })
         } else {
             Err(ObjectClientError::ServiceError(GetObjectError::NoSuchKey))
         }
@@ -908,11 +918,6 @@ enum MockObjectParts {
 
 #[cfg(test)]
 mod tests {
-    use std::{
-        sync::mpsc::{self, RecvTimeoutError},
-        thread,
-    };
-
     use futures::{pin_mut, StreamExt};
     use rand::{Rng, RngCore, SeedableRng};
     use rand_chacha::ChaChaRng;
@@ -920,6 +925,18 @@ mod tests {
 
     use super::*;
 
+    macro_rules! assert_client_error {
+        ($e:expr, $err:expr) => {
+            let err = $e.expect_err("should fail");
+            match err {
+                ObjectClientError::ClientError(MockClientError(m)) => {
+                    assert_eq!(&*m, $err);
+                }
+                _ => assert!(false, "wrong error type"),
+            }
+        };
+    }
+
     async fn test_get_object(key: &str, size: usize, range: Option<Range<u64>>) {
         let mut rng = ChaChaRng::seed_from_u64(0x12345678);
@@ -971,7 +988,7 @@
             bucket: "test_bucket".to_string(),
             part_size: 1024,
             unordered_list_seed: None,
-            enable_back_pressure: true,
+            enable_backpressure: true,
             initial_read_window_size: backpressure_read_window_size,
         });
@@ -992,9 +1009,12 @@
             assert_eq!(offset, next_offset, "wrong body part offset");
             next_offset += body.len() as u64;
             accum.extend_from_slice(&body[..]);
-            get_request
-                .as_mut()
-                .increment_read_window(backpressure_read_window_size);
+
+            while next_offset >= get_request.as_ref().read_window_end_offset() {
+                get_request
+                    .as_mut()
+                    .increment_read_window(backpressure_read_window_size);
+            }
         }
         let expected_range = range.unwrap_or(0..size as u64);
         let expected_range = expected_range.start as usize..expected_range.end as usize;
@@ -1024,18 +1044,6 @@
         rng.fill_bytes(&mut body);
         client.add_object("key1", body[..].into());
 
-        macro_rules! assert_client_error {
-            ($e:expr, $err:expr) => {
-                let err = $e.expect_err("should fail");
-                match err {
-                    ObjectClientError::ClientError(MockClientError(m)) => {
-                        assert_eq!(&*m, $err);
-                    }
-                    _ => assert!(false, "wrong error type"),
-                }
-            };
-        }
-
         assert!(matches!(
             client.get_object("wrong_bucket", "key1", None, None).await,
             Err(ObjectClientError::ServiceError(GetObjectError::NoSuchBucket))
         ));
@@ -1068,53 +1076,45 @@
         );
     }
 
-    // Verify that the request is blocked when we don't increment read window size
+    // Verify that an error is returned when we don't increment the read window size
     #[tokio::test]
     async fn verify_backpressure_get_object() {
         let key = "key1";
-        let size = 1000;
-        let range = 50..1000;
-        let mut rng = ChaChaRng::seed_from_u64(0x12345678);
 
+        let mut rng = ChaChaRng::seed_from_u64(0x12345678);
         let client = MockClient::new(MockClientConfig {
             bucket: "test_bucket".to_string(),
             part_size: 1024,
             unordered_list_seed: None,
-            enable_back_pressure: true,
+            enable_backpressure: true,
             initial_read_window_size: 256,
         });
 
-        let mut body = vec![0u8; size];
-        rng.fill_bytes(&mut body);
-        client.add_object(key, MockObject::from_bytes(&body, ETag::for_tests()));
+        let part_size = client.read_part_size().unwrap();
+        let size = part_size * 2;
+        let range = 0..(part_size + 1) as u64;
+
+        let mut expected_body = vec![0u8; size];
+        rng.fill_bytes(&mut expected_body);
+        client.add_object(key, MockObject::from_bytes(&expected_body, ETag::for_tests()));
 
         let mut get_request = client
             .get_object("test_bucket", key, Some(range.clone()), None)
             .await
             .expect("should not fail");
 
-        let mut accum = vec![];
-        let mut next_offset = range.start;
-
-        let (sender, receiver) = mpsc::channel();
-        thread::spawn(move || {
-            futures::executor::block_on(async move {
-                while let Some(r) = get_request.next().await {
-                    let (offset, body) = r.unwrap();
-                    assert_eq!(offset, next_offset, "wrong body part offset");
-                    next_offset += body.len() as u64;
-                    accum.extend_from_slice(&body[..]);
-                }
-                let expected_range = range;
-                let expected_range = expected_range.start as usize..expected_range.end as usize;
-                assert_eq!(&accum[..], &body[expected_range], "body does not match");
-                sender.send(accum).unwrap();
-            })
-        });
-        match receiver.recv_timeout(Duration::from_millis(100)) {
-            Ok(_) => panic!("request should have been blocked"),
-            Err(e) => assert_eq!(e, RecvTimeoutError::Timeout),
-        }
+        // Verify that we can receive some data since the window size is more than 0
+        let first_part = get_request.next().await.expect("result should not be empty");
+        let (offset, body) = first_part.unwrap();
+        assert_eq!(offset, 0, "wrong body part offset");
+
+        // Like the CRT, the mock always returns at least one part even if the window is smaller than that
+        let expected_range = range.start as usize..part_size;
+        assert_eq!(&body[..], &expected_body[expected_range]);
+
+        // This await should return an error because the current window is not enough to get the next part
+        let next = get_request.next().await.expect("result should not be empty");
+        assert_client_error!(next, "empty read window");
     }
 
     #[tokio::test]
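The mock now mirrors the window accounting the CRT client uses: `read_window_end_offset` starts at `range.start + initial_read_window_size`, and only `increment_read_window` moves it. A hypothetical test (not part of this diff; import paths assume the crate's `mock_client` module is available to the caller) walking through that arithmetic:

```rust
use futures::StreamExt;
use mountpoint_s3_client::mock_client::{MockClient, MockClientConfig};
use mountpoint_s3_client::ObjectClient;

#[tokio::test]
async fn sketch_mock_window_accounting() {
    // Window of 256 bytes, parts of 1024: the window is exhausted after one part.
    let client = MockClient::new(MockClientConfig {
        bucket: "test_bucket".to_string(),
        part_size: 1024,
        unordered_list_seed: None,
        enable_backpressure: true,
        initial_read_window_size: 256,
    });
    let body = vec![0u8; 2048];
    client.add_object("key", body[..].into());

    let mut request = client
        .get_object("test_bucket", "key", None, None)
        .await
        .unwrap();

    // First poll: next_offset (0) < read_window_end_offset (256), and the mock
    // returns a whole part even though the window is smaller than the part.
    let (offset, part) = request.next().await.unwrap().unwrap();
    assert_eq!((offset, part.len()), (0, 1024));

    // Second poll: next_offset (1024) >= read_window_end_offset (256), so it errors.
    assert!(request.next().await.unwrap().is_err());
}
```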
assert_client_error { - ($e:expr, $err:expr) => { - let err = $e.expect_err("should fail"); - match err { - ObjectClientError::ClientError(MockClientError(m)) => { - assert_eq!(&*m, $err); - } - _ => assert!(false, "wrong error type"), - } - }; - } - assert!(matches!( client.get_object("wrong_bucket", "key1", None, None).await, Err(ObjectClientError::ServiceError(GetObjectError::NoSuchBucket)) @@ -1068,53 +1076,45 @@ mod tests { ); } - // Verify that the request is blocked when we don't increment read window size + // Verify that an error is returned when we don't increment read window size #[tokio::test] async fn verify_backpressure_get_object() { let key = "key1"; - let size = 1000; - let range = 50..1000; - let mut rng = ChaChaRng::seed_from_u64(0x12345678); + let mut rng = ChaChaRng::seed_from_u64(0x12345678); let client = MockClient::new(MockClientConfig { bucket: "test_bucket".to_string(), part_size: 1024, unordered_list_seed: None, - enable_back_pressure: true, + enable_backpressure: true, initial_read_window_size: 256, }); - let mut body = vec![0u8; size]; - rng.fill_bytes(&mut body); - client.add_object(key, MockObject::from_bytes(&body, ETag::for_tests())); + let part_size = client.read_part_size().unwrap(); + let size = part_size * 2; + let range = 0..(part_size + 1) as u64; + + let mut expected_body = vec![0u8; size]; + rng.fill_bytes(&mut expected_body); + client.add_object(key, MockObject::from_bytes(&expected_body, ETag::for_tests())); let mut get_request = client .get_object("test_bucket", key, Some(range.clone()), None) .await .expect("should not fail"); - let mut accum = vec![]; - let mut next_offset = range.start; - - let (sender, receiver) = mpsc::channel(); - thread::spawn(move || { - futures::executor::block_on(async move { - while let Some(r) = get_request.next().await { - let (offset, body) = r.unwrap(); - assert_eq!(offset, next_offset, "wrong body part offset"); - next_offset += body.len() as u64; - accum.extend_from_slice(&body[..]); - } - let expected_range = range; - let expected_range = expected_range.start as usize..expected_range.end as usize; - assert_eq!(&accum[..], &body[expected_range], "body does not match"); - sender.send(accum).unwrap(); - }) - }); - match receiver.recv_timeout(Duration::from_millis(100)) { - Ok(_) => panic!("request should have been blocked"), - Err(e) => assert_eq!(e, RecvTimeoutError::Timeout), - } + // Verify that we can receive some data since the window size is more than 0 + let first_part = get_request.next().await.expect("result should not be empty"); + let (offset, body) = first_part.unwrap(); + assert_eq!(offset, 0, "wrong body part offset"); + + // The CRT always return at least a part even if the window is smaller than that + let expected_range = range.start as usize..part_size; + assert_eq!(&body[..], &expected_body[expected_range]); + + // This await should return an error because current window is not enough to get the next part + let next = get_request.next().await.expect("result should not be empty"); + assert_client_error!(next, "empty read window"); } #[tokio::test] diff --git a/mountpoint-s3-client/src/mock_client/throughput_client.rs b/mountpoint-s3-client/src/mock_client/throughput_client.rs index e3089ced4..a55c2b491 100644 --- a/mountpoint-s3-client/src/mock_client/throughput_client.rs +++ b/mountpoint-s3-client/src/mock_client/throughput_client.rs @@ -72,6 +72,11 @@ impl GetObjectRequest for ThroughputGetObjectRequest { let this = self.project(); this.request.increment_read_window(len); } + + fn 
diff --git a/mountpoint-s3-client/src/object_client.rs b/mountpoint-s3-client/src/object_client.rs
index 741c842d5..b92bf796c 100644
--- a/mountpoint-s3-client/src/object_client.rs
+++ b/mountpoint-s3-client/src/object_client.rs
@@ -85,6 +85,10 @@ pub trait ObjectClient {
     /// can be `None` if the client does not do multi-part operations.
     fn write_part_size(&self) -> Option<usize>;
 
+    /// Query the initial read window size this client uses for backpressure GetObject requests.
+    /// This can be `None` if backpressure is disabled.
+    fn initial_read_window_size(&self) -> Option<usize>;
+
     /// Delete a single object from the object store.
     ///
     /// DeleteObject will succeed even if the object within the bucket does not exist.
@@ -378,6 +382,10 @@ pub trait GetObjectRequest:
     /// If `enable_read_backpressure` is false this call will have no effect:
     /// no backpressure is applied and data is downloaded as fast as possible.
     fn increment_read_window(self: Pin<&mut Self>, len: usize);
+
+    /// Get the upper bound of the current read window. When backpressure is enabled,
+    /// [GetObjectRequest] can return data up to this offset *exclusively*.
+    fn read_window_end_offset(self: Pin<&Self>) -> u64;
 }
 
 /// A streaming put request which allows callers to asynchronously write the body of the request.
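Because the window end is exclusive, a caller that wants bytes up to some `target_offset` has to push `read_window_end_offset` past it. A small helper sketch built only on the two trait methods above (the function itself is hypothetical, not part of the API):

```rust
use std::pin::Pin;
use mountpoint_s3_client::types::GetObjectRequest;

// Grow the window just enough that data in current_end..target_offset may be
// returned; a no-op when the window already extends that far.
fn ensure_window<R: GetObjectRequest>(mut request: Pin<&mut R>, target_offset: u64) {
    let current_end = request.as_ref().read_window_end_offset();
    if target_offset > current_end {
        request.as_mut().increment_read_window((target_offset - current_end) as usize);
    }
}
```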
diff --git a/mountpoint-s3-client/src/s3_crt_client.rs b/mountpoint-s3-client/src/s3_crt_client.rs
index 25747ef96..3a10391e9 100644
--- a/mountpoint-s3-client/src/s3_crt_client.rs
+++ b/mountpoint-s3-client/src/s3_crt_client.rs
@@ -36,12 +36,10 @@ use pin_project::{pin_project, pinned_drop};
 use thiserror::Error;
 use tracing::{debug, error, trace, Span};
 
-use self::get_object::S3GetObjectRequest;
-use self::put_object::S3PutObjectRequest;
 use crate::endpoint_config::EndpointError;
 use crate::endpoint_config::{self, EndpointConfig};
-use crate::object_client::*;
 use crate::user_agent::UserAgent;
+use crate::{object_client::*, S3GetObjectRequest, S3PutObjectRequest};
 
 macro_rules! request_span {
     ($self:expr, $method:expr, $($field:tt)*) => {{
@@ -267,6 +265,8 @@ struct S3CrtClientInner {
     request_payer: Option<String>,
     read_part_size: usize,
     write_part_size: usize,
+    enable_backpressure: bool,
+    initial_read_window_size: usize,
     bucket_owner: Option<String>,
     credentials_provider: Option<CredentialsProvider>,
     host_resolver: HostResolver,
@@ -395,6 +395,8 @@ impl S3CrtClientInner {
             request_payer: config.request_payer,
             read_part_size: config.read_part_size,
             write_part_size: config.write_part_size,
+            enable_backpressure: config.read_backpressure,
+            initial_read_window_size: config.initial_read_window,
             bucket_owner: config.bucket_owner,
             credentials_provider: Some(credentials_provider),
             host_resolver,
@@ -974,6 +976,12 @@ pub enum S3RequestError {
     /// The request was throttled by S3
     #[error("Request throttled")]
     Throttled,
+
+    /// Cannot fetch more data because the current read window is exhausted. The read window
+    /// must be advanced using [GetObjectRequest::increment_read_window] to continue fetching
+    /// new data.
+    #[error("Polled for data with empty read window")]
+    EmptyReadWindow,
 }
 
 impl S3RequestError {
@@ -1178,6 +1186,14 @@ impl ObjectClient for S3CrtClient {
         Some(self.inner.write_part_size)
     }
 
+    fn initial_read_window_size(&self) -> Option<usize> {
+        if self.inner.enable_backpressure {
+            Some(self.inner.initial_read_window_size)
+        } else {
+            None
+        }
+    }
+
     async fn delete_object(
         &self,
         bucket: &str,
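`EmptyReadWindow` is recoverable by design: the request stays usable, and the caller decides when to feed it more window. A sketch of call-site handling (the helper is hypothetical; `S3GetObjectRequest` and `S3RequestError` are re-exported at the crate root, per the import change above and the test imports below):

```rust
use std::pin::Pin;
use futures::{Stream, StreamExt};
use mountpoint_s3_client::error::ObjectClientError;
use mountpoint_s3_client::types::GetObjectRequest;
use mountpoint_s3_client::{S3GetObjectRequest, S3RequestError};

// Treat EmptyReadWindow as "advance the window and try again" rather than
// as a failed download; any other outcome is passed through unchanged.
async fn next_part(
    mut request: Pin<&mut S3GetObjectRequest>,
    window: usize,
) -> Option<<S3GetObjectRequest as Stream>::Item> {
    loop {
        match request.next().await {
            Some(Err(ObjectClientError::ClientError(S3RequestError::EmptyReadWindow))) => {
                request.as_mut().increment_read_window(window);
            }
            other => return other,
        }
    }
}
```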
diff --git a/mountpoint-s3-client/src/s3_crt_client/get_object.rs b/mountpoint-s3-client/src/s3_crt_client/get_object.rs
index 4fe31eeb1..2a1e291ab 100644
--- a/mountpoint-s3-client/src/s3_crt_client/get_object.rs
+++ b/mountpoint-s3-client/src/s3_crt_client/get_object.rs
@@ -46,13 +46,16 @@ impl S3CrtClient {
                 .map_err(S3RequestError::construction_failure)?;
         }
 
-        if let Some(range) = range {
+        let next_offset = if let Some(range) = range {
             // Range HTTP header is bounded below *inclusive*
             let range_value = format!("bytes={}-{}", range.start, range.end.saturating_sub(1));
             message
                 .set_header(&Header::new("Range", range_value))
                 .map_err(S3RequestError::construction_failure)?;
-        }
+            range.start
+        } else {
+            0
+        };
 
         let key = format!("/{key}");
         message
@@ -60,6 +63,7 @@ impl S3CrtClient {
             .map_err(S3RequestError::construction_failure)?;
 
         let (sender, receiver) = futures::channel::mpsc::unbounded();
+        let read_window_end_offset = next_offset + self.inner.initial_read_window_size as u64;
 
         let mut options = S3CrtClientInner::new_meta_request_options(message, S3Operation::GetObject);
         options.part_size(self.inner.read_part_size as u64);
@@ -84,6 +88,9 @@ impl S3CrtClient {
             request,
             finish_receiver: receiver,
             finished: false,
+            enable_backpressure: self.inner.enable_backpressure,
+            next_offset,
+            read_window_end_offset,
         })
     }
 }
@@ -101,14 +108,25 @@ pub struct S3GetObjectRequest {
     #[pin]
     finish_receiver: UnboundedReceiver<Result<GetBodyPart, S3RequestError>>,
     finished: bool,
+    enable_backpressure: bool,
+    /// Next offset of the data to be polled from [poll_next]
+    next_offset: u64,
+    /// Upper bound of the current read window. When backpressure is enabled, [S3GetObjectRequest]
+    /// can return data up to this offset *exclusively*.
+    read_window_end_offset: u64,
 }
 
 impl GetObjectRequest for S3GetObjectRequest {
     type ClientError = S3RequestError;
 
     fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
+        self.read_window_end_offset += len as u64;
         self.request.meta_request.increment_read_window(len as u64);
     }
+
+    fn read_window_end_offset(self: Pin<&Self>) -> u64 {
+        self.read_window_end_offset
+    }
 }
 
 impl Stream for S3GetObjectRequest {
@@ -122,7 +140,14 @@ impl Stream for S3GetObjectRequest {
         let this = self.project();
 
         if let Poll::Ready(Some(val)) = this.finish_receiver.poll_next(cx) {
-            return Poll::Ready(Some(val.map_err(|e| ObjectClientError::ClientError(e.into()))));
+            let result = match val {
+                Ok(item) => {
+                    *this.next_offset = item.0 + item.1.len() as u64;
+                    Some(Ok(item))
+                }
+                Err(e) => Some(Err(ObjectClientError::ClientError(e.into()))),
+            };
+            return Poll::Ready(result);
         }
 
         match this.request.poll(cx) {
@@ -134,7 +159,18 @@ impl Stream for S3GetObjectRequest {
                 *this.finished = true;
                 Poll::Ready(Some(Err(e)))
             }
-            Poll::Pending => Poll::Pending,
+            Poll::Pending => {
+                // If the request is not finished yet but the read window is not enough to poll
+                // the next chunk, we return an error instead of keeping the request blocked.
+                // This prevents a risk of deadlock when using the [S3CrtClient]; users must
+                // implement their own logic if they really want to block a [GetObjectRequest].
+                if *this.enable_backpressure && this.read_window_end_offset <= this.next_offset {
+                    return Poll::Ready(Some(Err(ObjectClientError::ClientError(
+                        S3RequestError::EmptyReadWindow,
+                    ))));
+                }
+                Poll::Pending
+            }
         }
     }
 }
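The observable effect of this `Poll::Pending` change at the stream surface: a consumer that never feeds the window used to hang forever, since no waker was ever going to fire once the CRT stopped delivering parts, and now gets a terminal error it can act on. A hypothetical illustration:

```rust
use std::pin::Pin;
use futures::StreamExt;
use mountpoint_s3_client::S3GetObjectRequest;

// Before this diff, this loop could await forever once the window emptied;
// now the EmptyReadWindow error makes it terminate, so the deadlock becomes
// a visible, handleable failure.
async fn drain_without_feeding(mut request: Pin<&mut S3GetObjectRequest>) -> usize {
    let mut bytes = 0;
    while let Some(part) = request.next().await {
        match part {
            Ok((_offset, body)) => bytes += body.len(),
            Err(_) => break, // e.g. EmptyReadWindow when the window runs out
        }
    }
    bytes
}
```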
diff --git a/mountpoint-s3-client/tests/common/mod.rs b/mountpoint-s3-client/tests/common/mod.rs
index 581e3d2e2..2c84cbccc 100644
--- a/mountpoint-s3-client/tests/common/mod.rs
+++ b/mountpoint-s3-client/tests/common/mod.rs
@@ -74,15 +74,16 @@ pub fn get_test_client() -> S3CrtClient {
     S3CrtClient::new(S3ClientConfig::new().endpoint_config(endpoint_config)).expect("could not create test client")
 }
 
-pub fn get_test_backpressure_client(initial_read_window: usize) -> S3CrtClient {
+pub fn get_test_backpressure_client(initial_read_window: usize, part_size: Option<usize>) -> S3CrtClient {
     let endpoint_config = EndpointConfig::new(&get_test_region());
-    S3CrtClient::new(
-        S3ClientConfig::new()
-            .endpoint_config(endpoint_config)
-            .read_backpressure(true)
-            .initial_read_window(initial_read_window),
-    )
-    .expect("could not create test client")
+    let mut config = S3ClientConfig::new()
+        .endpoint_config(endpoint_config)
+        .read_backpressure(true)
+        .initial_read_window(initial_read_window);
+    if let Some(part_size) = part_size {
+        config = config.part_size(part_size);
+    }
+    S3CrtClient::new(config).expect("could not create test client")
 }
 
 pub fn get_test_bucket_and_prefix(test_name: &str) -> (String, String) {
@@ -203,7 +204,6 @@ pub async fn check_backpressure_get_result(
     range: Option<Range<u64>>,
     expected: &[u8],
 ) {
-    let mut accum_read_window = read_window;
     let mut accum = vec![];
     let mut next_offset = range.map(|r| r.start).unwrap_or(0);
     pin_mut!(result);
@@ -215,9 +215,8 @@ pub async fn check_backpressure_get_result(
         // We run out of data to read if the read window is smaller than the accumulated
        // length of data, so we keep adding window size, otherwise the request will be blocked.
-        while accum_read_window <= accum.len() {
+        while next_offset >= result.as_ref().read_window_end_offset() {
             result.as_mut().increment_read_window(read_window);
-            accum_read_window += read_window;
         }
     }
     assert_eq!(&accum[..], expected, "body does not match");
diff --git a/mountpoint-s3-client/tests/get_object.rs b/mountpoint-s3-client/tests/get_object.rs
index 916bfaf2a..15a0eb28d 100644
--- a/mountpoint-s3-client/tests/get_object.rs
+++ b/mountpoint-s3-client/tests/get_object.rs
@@ -5,17 +5,15 @@ pub mod common;
 use std::ops::Range;
 use std::option::Option::None;
 use std::str::FromStr;
-use std::sync::mpsc::{self, RecvTimeoutError};
-use std::thread;
-use std::time::Duration;
 
 use aws_sdk_s3::primitives::ByteStream;
 use bytes::Bytes;
 use common::*;
+use futures::pin_mut;
 use futures::stream::StreamExt;
 use mountpoint_s3_client::error::{GetObjectError, ObjectClientError};
-use mountpoint_s3_client::types::ETag;
-use mountpoint_s3_client::{ObjectClient, S3CrtClient};
+use mountpoint_s3_client::types::{ETag, GetObjectRequest};
+use mountpoint_s3_client::{ObjectClient, S3CrtClient, S3RequestError};
 
 use test_case::test_case;
@@ -79,7 +77,7 @@ async fn test_get_object_backpressure(size: usize, range: Option<Range<u64>>) {
         .unwrap();
 
     let initial_window_size = 8 * 1024 * 1024;
-    let client: S3CrtClient = get_test_backpressure_client(initial_window_size);
+    let client: S3CrtClient = get_test_backpressure_client(initial_window_size, None);
 
     let request = client
         .get_object(&bucket, &key, range.clone(), None)
@@ -92,17 +90,17 @@ async fn test_get_object_backpressure(size: usize, range: Option<Range<u64>>) {
     check_backpressure_get_result(initial_window_size, request, range, expected).await;
 }
 
-// Verify that the request is blocked when we don't increment read window size
+// Verify that an error is returned when we don't increment the read window size
 #[tokio::test]
 async fn verify_backpressure_get_object() {
     let initial_window_size = 256;
-    let client: S3CrtClient = get_test_backpressure_client(initial_window_size);
+    let client: S3CrtClient = get_test_backpressure_client(initial_window_size, None);
     let part_size = client.read_part_size().unwrap();
     let size = part_size * 2;
     let range = 0..(part_size + 1) as u64;
 
     let sdk_client = get_test_sdk_client().await;
-    let (bucket, prefix) = get_test_bucket_and_prefix("test_get_object");
+    let (bucket, prefix) = get_test_bucket_and_prefix("verify_backpressure_get_object");
     let key = format!("{prefix}/test");
 
     let expected_body = vec![0x42; size];
@@ -129,19 +127,69 @@ async fn verify_backpressure_get_object() {
     let expected_range = range.start as usize..part_size;
     assert_eq!(&body[..], &expected_body[expected_range]);
 
-    let (sender, receiver) = mpsc::channel();
-    thread::spawn(move || {
-        futures::executor::block_on(async move {
-            // This await should be blocked
-            let second_part = get_request.next().await.unwrap();
-            let (_offset, body) = second_part.unwrap();
-            sender.send(body).unwrap();
-        })
-    });
-    match receiver.recv_timeout(Duration::from_millis(1000)) {
-        Ok(_) => panic!("request should have been blocked"),
-        Err(e) => assert_eq!(e, RecvTimeoutError::Timeout),
-    }
+    // This await should return an error because the current window is not enough to get the next part
+    let next = get_request.next().await.expect("result should not be empty");
+    assert!(matches!(
+        next,
+        Err(ObjectClientError::ClientError(S3RequestError::EmptyReadWindow))
+    ));
+}
+
+#[tokio::test]
+async fn test_mutated_during_get_object_backpressure() {
+    let part_size = 8 * 1024 * 1024;
+    let initial_window_size = part_size;
+    let client: S3CrtClient = get_test_backpressure_client(initial_window_size, Some(part_size));
+
+    let size = part_size * 2;
+    let range = 0..(part_size + 1) as u64;
+    let sdk_client = get_test_sdk_client().await;
+    let (bucket, prefix) = get_test_bucket_and_prefix("test_get_object");
+
+    let key = format!("{prefix}/test");
+    let expected_body = vec![0x42; size];
+    sdk_client
+        .put_object()
+        .bucket(&bucket)
+        .key(&key)
+        .body(ByteStream::from(expected_body.clone()))
+        .send()
+        .await
+        .unwrap();
+
+    let mut get_request = client
+        .get_object(&bucket, &key, Some(range.clone()), None)
+        .await
+        .expect("should not fail");
+
+    // Verify that we can receive the first part successfully
+    let first_part = get_request.next().await.expect("result should not be empty");
+    let (offset, body) = first_part.unwrap();
+    assert_eq!(offset, 0, "wrong body part offset");
+
+    let expected_range = range.start as usize..part_size;
+    assert_eq!(&body[..], &expected_body[expected_range]);
+
+    // Overwrite the object
+    let new_content = vec![0xaa; size];
+    sdk_client
+        .put_object()
+        .bucket(&bucket)
+        .key(&key)
+        .body(ByteStream::from(new_content))
+        .send()
+        .await
+        .unwrap();
+
+    pin_mut!(get_request);
+    get_request.as_mut().increment_read_window(part_size);
+
+    // Verify that the next part is an error
+    let next = get_request.next().await.expect("result should not be empty");
+    assert!(matches!(
+        next,
+        Err(ObjectClientError::ServiceError(GetObjectError::PreconditionFailed))
+    ));
+}
 
 #[tokio::test]