Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow enforcing SSE settings for the cache bucket #1131

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/failure_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ impl<Client: ObjectClient + Send + Sync, FailState: Send + Sync> GetObjectReques
self.request.get_object_checksum().await
}

async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
Ok((None, None))
}

fn increment_read_window(self: Pin<&mut Self>, len: usize) {
let this = self.project();
this.request.increment_read_window(len);
Expand Down
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,12 @@ impl GetObjectRequest for MockGetObjectRequest {
Ok(self.object.checksum.clone())
}

async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
Ok((None, None))
}

fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
self.read_window_end_offset += len as u64;
}
Expand Down
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/mock_client/throughput_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl GetObjectRequest for ThroughputGetObjectRequest {
Ok(self.request.object.checksum.clone())
}

async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
Ok((None, None))
}

fn increment_read_window(self: Pin<&mut Self>, len: usize) {
let this = self.project();
this.request.increment_read_window(len);
Expand Down
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/object_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,12 @@ pub trait GetObjectRequest:
/// Get the object's checksum, if uploaded with one
async fn get_object_checksum(&self) -> ObjectClientResult<Checksum, GetObjectError, Self::ClientError>;

/// Get the object's SSE type and KMS Key ARN, as defined in:
/// https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#key-id
async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError>;

/// Increment the flow-control window, so that response data continues downloading.
///
/// If the client was created with `enable_read_backpressure` set true,
Expand Down
8 changes: 8 additions & 0 deletions mountpoint-s3-client/src/s3_crt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,14 @@ fn parse_checksum(headers: &Headers) -> Result<Checksum, HeadersError> {
})
}

/// Extract the SSE-type and SSE-key information from headers
fn parse_object_sse(headers: &Headers) -> Result<(Option<String>, Option<String>), HeadersError> {
let sse_type = headers.get_as_optional_string("x-amz-server-side-encryption")?;
let sse_kms_key_id = headers.get_as_optional_string("x-amz-server-side-encryption-aws-kms-key-id")?;

Ok((sse_type, sse_kms_key_id))
}

/// Try to parse a modeled error out of a failing meta request
fn try_parse_generic_error(request_result: &MetaRequestResult) -> Option<S3RequestError> {
/// Look for a redirect header pointing to a different region for the bucket
Expand Down
11 changes: 10 additions & 1 deletion mountpoint-s3-client/src/s3_crt_client/get_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use crate::object_client::{
Checksum, GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, ObjectClientResult, ObjectMetadata,
};
use crate::s3_crt_client::{
parse_checksum, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation, S3RequestError,
parse_checksum, parse_object_sse, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation,
S3RequestError,
};
use crate::types::ChecksumMode;

Expand Down Expand Up @@ -204,6 +205,14 @@ impl GetObjectRequest for S3GetObjectRequest {
parse_checksum(&headers).map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e))))
}

async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
let headers = self.get_object_headers().await?;
parse_object_sse(&headers)
.map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e))))
}

fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
self.read_window_end_offset += len as u64;
self.request.meta_request.increment_read_window(len as u64);
Expand Down
14 changes: 14 additions & 0 deletions mountpoint-s3-client/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ pub fn get_test_bucket() -> String {
}
}

/// An S3 Express bucket with SSE-S3 set as the default encryption
#[cfg(feature = "s3express_tests")]
pub fn get_express_bucket() -> String {
std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME")
.expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME to run integration tests")
}

/// An S3 Express bucket with SSE-KMS set as a default encryption with a key matching the `KMS_TEST_KEY_ID`
#[cfg(feature = "s3express_tests")]
pub fn get_express_sse_kms_bucket() -> String {
std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME_SSE_KMS")
.expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME_SSE_KMS to run integration tests")
}

pub fn get_test_kms_key_id() -> String {
std::env::var("KMS_TEST_KEY_ID").expect("Set KMS_TEST_KEY_ID to run integration tests")
}
Expand Down
56 changes: 56 additions & 0 deletions mountpoint-s3-client/tests/get_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,59 @@ async fn test_get_object_checksum_checksums_disabled() {
.await
.expect_err("should not return a checksum object as not requested");
}

#[test_case("aws:kms", Some(get_test_kms_key_id()))]
#[test_case("aws:kms:dsse", Some(get_test_kms_key_id()))]
#[test_case("AES256", None)]
#[tokio::test]
#[cfg(not(feature = "s3express_tests"))]
async fn test_get_object_sse(sse_type: &str, kms_key_id: Option<String>) {
test_get_object_sse_base(sse_type, kms_key_id, get_test_bucket()).await;
}

// We have a separate set of tests for express because:
// 1. via SDK / CRT we can only put an object with the settings that match bucket's defaults:
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-express-specifying-kms-encryption.html
// 2. Express doesn't currently support `aws:kms:dsse`.
#[test_case("AES256", None, get_express_bucket())]
#[test_case("aws:kms", Some(get_test_kms_key_id()), get_express_sse_kms_bucket())]
#[tokio::test]
#[cfg(feature = "s3express_tests")]
async fn test_get_object_sse(sse_type: &str, kms_key_id: Option<String>, bucket: String) {
test_get_object_sse_base(sse_type, kms_key_id, bucket).await;
}

async fn test_get_object_sse_base(sse_type: &str, kms_key_id: Option<String>, bucket: String) {
let sdk_client = get_test_sdk_client().await;
let prefix = get_unique_test_prefix("test_get_object_sse");

let key = format!("{prefix}/test");
let body = vec![0x42; 42];
let mut request = sdk_client
.put_object()
.bucket(&bucket)
.key(&key)
.body(ByteStream::from(body.clone()))
.server_side_encryption(sse_type.into());
if let Some(kms_key_id) = kms_key_id.as_ref() {
request = request.ssekms_key_id(kms_key_id)
}
request.send().await.unwrap();

let client: S3CrtClient = get_test_client();

let result = client
.get_object(
&bucket,
&key,
&GetObjectParams::new().checksum_mode(Some(ChecksumMode::Enabled)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need checksums enabled here right?

)
.await
.expect("get_object should succeed");

let (received_sse, received_key_id) = result.get_object_sse().await.expect("should return sse settings");
assert_eq!(received_sse.expect("sse type must be some"), sse_type);
if let Some(kms_key_id) = kms_key_id {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's None, we should assert that there's no key received, or it's S3 SSE

assert_eq!(received_key_id.expect("sse key must be some"), kms_key_id);
}
}
47 changes: 46 additions & 1 deletion mountpoint-s3-client/tests/put_object_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ async fn test_put_object_storage_class(storage_class: &str) {
assert_eq!(storage_class, attributes.storage_class.unwrap().as_str());
}

#[cfg(not(feature = "s3express_tests"))]
async fn check_sse(
bucket: &String,
key: &String,
Expand Down Expand Up @@ -292,6 +291,52 @@ async fn test_put_object_sse(sse_type: Option<&str>, kms_key_id: Option<String>)
check_sse(&bucket, &key, sse_type, &kms_key_id, put_object_result).await;
}

#[test_case(Some("aws:kms"), Some(get_test_kms_key_id()), get_express_sse_kms_bucket(), false)]
#[test_case(Some("aws:kms"), None, get_express_sse_kms_bucket(), false)]
#[test_case(Some("aws:kms"), Some(get_test_kms_key_id()), get_express_bucket(), true)]
#[test_case(Some("aws:kms"), None, get_express_bucket(), true)]
#[test_case(Some("AES256"), None, get_express_bucket(), false)]
#[test_case(Some("AES256"), None, get_express_sse_kms_bucket(), true)]
#[test_case(None, None, get_express_bucket(), false)]
#[tokio::test]
#[cfg(feature = "s3express_tests")]
async fn test_put_object_sse(sse_type: Option<&str>, kms_key_id: Option<String>, bucket: String, should_fail: bool) {
let client_config = S3ClientConfig::new().endpoint_config(get_test_endpoint_config());
let client = S3CrtClient::new(client_config).expect("could not create test client");
let request_params = PutObjectSingleParams::new()
.server_side_encryption(sse_type.map(|value| value.to_owned()))
.ssekms_key_id(kms_key_id.to_owned());

// Make a request
let prefix = get_unique_test_prefix("test_put_object_sse");
let key = format!("{prefix}hello");

let mut rng = rand::thread_rng();
let mut contents = vec![0u8; 32];
rng.fill(&mut contents[..]);

let put_object_result = client
.put_object_single(&bucket, &key, &request_params, &contents)
.await;

if should_fail {
assert!(put_object_result.is_err(), "put request should fail");
return;
} else {
assert!(put_object_result.is_ok(), "put request should succeed");
}

// Check sse of the object via SDK and the values returned in response to the PUT
check_sse(
&bucket,
&key,
sse_type,
&kms_key_id,
put_object_result.expect("put request should succeed"),
)
.await;
}

#[tokio::test]
async fn test_put_object_header() {
let (bucket, prefix) = get_test_bucket_and_prefix("test_put_object_header");
Expand Down
16 changes: 14 additions & 2 deletions mountpoint-s3/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,13 @@ where
match (args.disk_data_cache_config(), args.express_data_cache_config()) {
(None, Some((config, bucket_name, cache_bucket_name))) => {
tracing::trace!("using S3 Express One Zone bucket as a cache for object content");
let express_cache = ExpressDataCache::new(client.clone(), config, bucket_name, cache_bucket_name);
let express_cache = ExpressDataCache::new(
client.clone(),
config,
bucket_name,
cache_bucket_name,
filesystem_config.server_side_encryption.clone(),
);

let prefetcher = caching_prefetch(express_cache, runtime, prefetcher_config);
let fuse_session = create_filesystem(
Expand Down Expand Up @@ -933,7 +939,13 @@ where
(Some((disk_data_cache_config, cache_dir_path)), Some((config, bucket_name, cache_bucket_name))) => {
tracing::trace!("using both local disk and S3 Express One Zone bucket as a cache for object content");
let (managed_cache_dir, disk_cache) = create_disk_cache(cache_dir_path, disk_data_cache_config)?;
let express_cache = ExpressDataCache::new(client.clone(), config, bucket_name, cache_bucket_name);
let express_cache = ExpressDataCache::new(
client.clone(),
config,
bucket_name,
cache_bucket_name,
filesystem_config.server_side_encryption.clone(),
);
let cache = MultilevelDataCache::new(Arc::new(disk_cache), express_cache, runtime.clone());

let prefetcher = caching_prefetch(cache, runtime, prefetcher_config);
Expand Down
Loading
Loading