Skip to content

Commit

Permalink
Allow enforcing SSE settings for the cache bucket
Browse files Browse the repository at this point in the history
Signed-off-by: Vlad Volodkin <vlaad@amazon.com>
  • Loading branch information
Vlad Volodkin committed Nov 13, 2024
1 parent 9206ed4 commit 62dca0f
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 24 deletions.
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/failure_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ impl<Client: ObjectClient + Send + Sync, FailState: Send + Sync> GetObjectReques
self.request.get_object_checksum().await
}

async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
Ok((None, None))
}

fn increment_read_window(self: Pin<&mut Self>, len: usize) {
let this = self.project();
this.request.increment_read_window(len);
Expand Down
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/mock_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,12 @@ impl GetObjectRequest for MockGetObjectRequest {
Ok(self.object.checksum.clone())
}

    /// Get the object's server-side encryption settings.
    ///
    /// The mock client does not model SSE, so both the encryption type and
    /// the KMS key id are reported as absent.
    async fn get_object_sse(
        &self,
    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
        Ok((None, None))
    }

fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
self.read_window_end_offset += len as u64;
}
Expand Down
6 changes: 6 additions & 0 deletions mountpoint-s3-client/src/mock_client/throughput_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl GetObjectRequest for ThroughputGetObjectRequest {
Ok(self.request.object.checksum.clone())
}

    /// Get the object's server-side encryption settings.
    ///
    /// This wraps the mock client, which does not model SSE, so both values
    /// are reported as absent.
    async fn get_object_sse(
        &self,
    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
        Ok((None, None))
    }

fn increment_read_window(self: Pin<&mut Self>, len: usize) {
let this = self.project();
this.request.increment_read_window(len);
Expand Down
4 changes: 4 additions & 0 deletions mountpoint-s3-client/src/object_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,10 @@ pub trait GetObjectRequest:
/// Get the object's checksum, if uploaded with one
async fn get_object_checksum(&self) -> ObjectClientResult<Checksum, GetObjectError, Self::ClientError>;

    /// Get the object's server-side encryption settings as an
    /// `(sse_type, sse_kms_key_id)` pair; either value may be absent.
    async fn get_object_sse(
        &self,
    ) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError>;

/// Increment the flow-control window, so that response data continues downloading.
///
/// If the client was created with `enable_read_backpressure` set true,
Expand Down
8 changes: 8 additions & 0 deletions mountpoint-s3-client/src/s3_crt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,14 @@ fn parse_checksum(headers: &Headers) -> Result<Checksum, HeadersError> {
})
}

/// Read the server-side encryption settings out of response headers: the SSE
/// type (`x-amz-server-side-encryption`) and the KMS key id
/// (`x-amz-server-side-encryption-aws-kms-key-id`). Either header may be absent.
fn parse_object_sse(headers: &Headers) -> Result<(Option<String>, Option<String>), HeadersError> {
    Ok((
        headers.get_as_optional_string("x-amz-server-side-encryption")?,
        headers.get_as_optional_string("x-amz-server-side-encryption-aws-kms-key-id")?,
    ))
}

/// Try to parse a modeled error out of a failing meta request
fn try_parse_generic_error(request_result: &MetaRequestResult) -> Option<S3RequestError> {
/// Look for a redirect header pointing to a different region for the bucket
Expand Down
11 changes: 10 additions & 1 deletion mountpoint-s3-client/src/s3_crt_client/get_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use crate::object_client::{
Checksum, GetBodyPart, GetObjectError, GetObjectParams, ObjectClientError, ObjectClientResult, ObjectMetadata,
};
use crate::s3_crt_client::{
parse_checksum, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation, S3RequestError,
parse_checksum, parse_object_sse, GetObjectRequest, S3CrtClient, S3CrtClientInner, S3HttpRequest, S3Operation,
S3RequestError,
};
use crate::types::ChecksumMode;

Expand Down Expand Up @@ -204,6 +205,14 @@ impl GetObjectRequest for S3GetObjectRequest {
parse_checksum(&headers).map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e))))
}

async fn get_object_sse(
&self,
) -> ObjectClientResult<(Option<String>, Option<String>), GetObjectError, Self::ClientError> {
let headers = self.get_object_headers().await?;
parse_object_sse(&headers)
.map_err(|e| ObjectClientError::ClientError(S3RequestError::InternalError(Box::new(e))))
}

fn increment_read_window(mut self: Pin<&mut Self>, len: usize) {
self.read_window_end_offset += len as u64;
self.request.meta_request.increment_read_window(len as u64);
Expand Down
16 changes: 14 additions & 2 deletions mountpoint-s3/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,13 @@ where
match (args.disk_data_cache_config(), args.express_data_cache_config()) {
(None, Some((config, bucket_name, cache_bucket_name))) => {
tracing::trace!("using S3 Express One Zone bucket as a cache for object content");
let express_cache = ExpressDataCache::new(client.clone(), config, bucket_name, cache_bucket_name);
let express_cache = ExpressDataCache::new(
client.clone(),
config,
bucket_name,
cache_bucket_name,
filesystem_config.server_side_encryption.clone(),
);

let prefetcher = caching_prefetch(express_cache, runtime, prefetcher_config);
let fuse_session = create_filesystem(
Expand Down Expand Up @@ -933,7 +939,13 @@ where
(Some((disk_data_cache_config, cache_dir_path)), Some((config, bucket_name, cache_bucket_name))) => {
tracing::trace!("using both local disk and S3 Express One Zone bucket as a cache for object content");
let (managed_cache_dir, disk_cache) = create_disk_cache(cache_dir_path, disk_data_cache_config)?;
let express_cache = ExpressDataCache::new(client.clone(), config, bucket_name, cache_bucket_name);
let express_cache = ExpressDataCache::new(
client.clone(),
config,
bucket_name,
cache_bucket_name,
filesystem_config.server_side_encryption.clone(),
);
let cache = MultilevelDataCache::new(Arc::new(disk_cache), express_cache, runtime.clone());

let prefetcher = caching_prefetch(cache, runtime, prefetcher_config);
Expand Down
43 changes: 39 additions & 4 deletions mountpoint-s3/src/data_cache/express_data_cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::object::ObjectId;
use crate::ServerSideEncryption;

use super::{BlockIndex, ChecksummedBytes, DataCache, DataCacheError, DataCacheResult};

Expand Down Expand Up @@ -38,6 +39,7 @@ pub struct ExpressDataCache<Client: ObjectClient> {
config: ExpressDataCacheConfig,
/// Name of the S3 Express bucket to store the blocks.
bucket_name: String,
sse: ServerSideEncryption,
}

impl<S, C> From<ObjectClientError<S, C>> for DataCacheError
Expand All @@ -57,19 +59,27 @@ where
/// Create a new instance.
///
/// TODO: consider adding some validation of the bucket.
pub fn new(client: Client, config: ExpressDataCacheConfig, source_bucket_name: &str, bucket_name: &str) -> Self {
    pub fn new(
        client: Client,
        config: ExpressDataCacheConfig,
        source_bucket_name: &str,
        bucket_name: &str,
        sse: ServerSideEncryption,
    ) -> Self {
        // Namespace all cache keys by cache version, block size, and source
        // bucket, so caches with incompatible layouts or different sources
        // never collide inside the same cache bucket.
        let prefix = hex::encode(
            Sha256::new()
                .chain_update(CACHE_VERSION.as_bytes())
                .chain_update(config.block_size.to_be_bytes())
                .chain_update(source_bucket_name.as_bytes())
                .finalize(),
        );

        Self {
            client,
            prefix,
            config,
            bucket_name: bucket_name.to_owned(),
            // SSE settings enforced on cache reads/writes — NOTE(review): not
            // part of the hashed prefix; confirm mixing SSE configs in one
            // cache bucket is intended.
            sse,
        }
    }
}
Expand Down Expand Up @@ -126,6 +136,15 @@ where
Err(e) => return Err(e.into()),
}
}

let (sse_type, sse_kms_key_id) = result
.get_object_sse()
.await
.map_err(|err| DataCacheError::IoFailure(err.into()))?; // TODO: avoid PUTing the block after this?
self.sse
.verify_response(sse_type.as_deref(), sse_kms_key_id.as_deref())
.map_err(|err| DataCacheError::IoFailure(err.into()))?;

let buffer = buffer.freeze();
DataCacheResult::Ok(Some(buffer.into()))
}
Expand All @@ -149,7 +168,16 @@ where
let object_key = block_key(&self.prefix, &cache_key, block_idx);

// TODO: ideally we should use a simple Put rather than MPU.
let params = PutObjectParams::new();
let mut params = PutObjectParams::new();
let (sse_type, key_id) = self
.sse
.clone()
.into_inner()
.map_err(|err| DataCacheError::IoFailure(err.into()))?;
// TODO: In the Zonal endpoint API calls (except CopyObject and UploadPartCopy), you can't override the values of the encryption settings (x-amz-server-side-encryption, x-amz-server-side-encryption-aws-kms-key-id, ...) from the CreateSession request.
params = params.server_side_encryption(sse_type);
params = params.ssekms_key_id(key_id);

let mut req = self
.client
.put_object(&self.bucket_name, &object_key, &params)
Expand All @@ -158,6 +186,7 @@ where
let (data, _crc) = bytes.into_inner().map_err(|_| DataCacheError::InvalidBlockContent)?;
req.write(&data).await?;
req.complete().await?;
// TODO: verify that headers of the PUT response match the expected SSE; what to do on error though?

DataCacheResult::Ok(())
}
Expand Down Expand Up @@ -206,7 +235,7 @@ mod tests {
block_size,
..Default::default()
};
let cache = ExpressDataCache::new(client, config, "unique source description", bucket);
let cache = ExpressDataCache::new(client, config, "unique source description", bucket, Default::default());

let data_1 = ChecksummedBytes::new("Foo".into());
let data_2 = ChecksummedBytes::new("Bar".into());
Expand Down Expand Up @@ -299,7 +328,13 @@ mod tests {
..Default::default()
};
let client = Arc::new(MockClient::new(config));
let cache = ExpressDataCache::new(client.clone(), Default::default(), "unique source description", bucket);
let cache = ExpressDataCache::new(
client.clone(),
Default::default(),
"unique source description",
bucket,
Default::default(),
);
let data_1 = vec![0u8; 1024 * 1024 + 1];
let data_1 = ChecksummedBytes::new(data_1.into());
let cache_key_1 = ObjectId::new("a".into(), ETag::for_tests());
Expand Down
8 changes: 7 additions & 1 deletion mountpoint-s3/src/data_cache/multilevel_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,13 @@ mod tests {
..Default::default()
};
let client = MockClient::new(config);
let cache = ExpressDataCache::new(client.clone(), Default::default(), "unique source description", bucket);
let cache = ExpressDataCache::new(
client.clone(),
Default::default(),
"unique source description",
bucket,
Default::default(),
);
(client, cache)
}

Expand Down
20 changes: 12 additions & 8 deletions mountpoint-s3/tests/common/fuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub trait TestSessionCreator: FnOnce(&str, TestSessionConfig) -> TestSession {}
// `FnOnce(...)` in place of `impl TestSessionCreator`.
impl<T> TestSessionCreator for T where T: FnOnce(&str, TestSessionConfig) -> TestSession {}

fn create_fuse_session<Client, Prefetcher>(
pub fn create_fuse_session<Client, Prefetcher>(
client: Client,
prefetcher: Prefetcher,
bucket: &str,
Expand Down Expand Up @@ -363,12 +363,7 @@ pub mod s3_session {
let (bucket, prefix) = get_test_bucket_and_prefix(test_name);
let region = get_test_region();

let client_config = S3ClientConfig::default()
.part_size(test_config.part_size)
.endpoint_config(get_test_endpoint_config())
.read_backpressure(true)
.initial_read_window(test_config.initial_read_window_size);
let client = S3CrtClient::new(client_config).unwrap();
let client = create_crt_client(test_config.part_size, test_config.initial_read_window_size);
let runtime = client.event_loop_group();
let prefetcher = caching_prefetch(cache, runtime, test_config.prefetcher_config);
let session = create_fuse_session(
Expand All @@ -385,7 +380,16 @@ pub mod s3_session {
}
}

fn create_test_client(region: &str, bucket: &str, prefix: &str) -> impl TestClient {
/// Build an `S3CrtClient` for integration tests against the test endpoint,
/// with read backpressure enabled and the given part size and initial read
/// window.
pub fn create_crt_client(part_size: usize, initial_read_window_size: usize) -> S3CrtClient {
    let client_config = S3ClientConfig::default()
        .part_size(part_size)
        .endpoint_config(get_test_endpoint_config())
        .read_backpressure(true)
        .initial_read_window(initial_read_window_size);
    // Test helper: fail fast with a descriptive message rather than a bare unwrap.
    S3CrtClient::new(client_config).expect("failed to create S3 CRT client")
}

pub fn create_test_client(region: &str, bucket: &str, prefix: &str) -> impl TestClient {
let sdk_client = tokio_block_on(async { get_test_sdk_client(region).await });
SDKTestClient {
prefix: prefix.to_owned(),
Expand Down
28 changes: 20 additions & 8 deletions mountpoint-s3/tests/common/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,35 @@ use rand_chacha::rand_core::OsRng;
use crate::common::tokio_block_on;

/// Pick the integration-test bucket — standard, or S3 Express One Zone when
/// the `s3express_tests` feature is enabled — and generate a unique key
/// prefix for `test_name`.
///
/// Returns `(bucket_name, prefix)`; the prefix always ends with `/`.
pub fn get_test_bucket_and_prefix(test_name: &str) -> (String, String) {
    // Select the bucket at compile time via `#[cfg]` attributes: the express
    // accessor only exists when the feature is enabled, so a runtime `cfg!`
    // branch would not compile.
    #[cfg(not(feature = "s3express_tests"))]
    let bucket = get_standard_bucket();
    #[cfg(feature = "s3express_tests")]
    let bucket = get_express_bucket();

    let prefix = get_test_prefix(test_name);

    (bucket, prefix)
}

/// Build a unique S3 key prefix for `test_name`, rooted at the optional
/// `S3_BUCKET_TEST_PREFIX` environment variable (default `mountpoint-test/`).
///
/// # Panics
///
/// Panics if `S3_BUCKET_TEST_PREFIX` is set but does not end with `/`.
pub fn get_test_prefix(test_name: &str) -> String {
    // Generate a random nonce to make sure this prefix is truly unique
    let nonce = OsRng.next_u64();

    // Prefix always has a trailing "/" to keep meaning in sync with the S3 API.
    // `unwrap_or_else` defers the default-string allocation to the unset case.
    let prefix = std::env::var("S3_BUCKET_TEST_PREFIX").unwrap_or_else(|_| String::from("mountpoint-test/"));
    assert!(prefix.ends_with('/'), "S3_BUCKET_TEST_PREFIX should end in '/'");

    format!("{prefix}{test_name}/{nonce}/")
}

(bucket, prefix)
/// Name of the S3 Express One Zone bucket used by integration tests, read
/// from the `S3_EXPRESS_ONE_ZONE_BUCKET_NAME` environment variable.
///
/// Panics with a setup hint if the variable is unset.
#[cfg(feature = "s3express_tests")]
pub fn get_express_bucket() -> String {
    std::env::var("S3_EXPRESS_ONE_ZONE_BUCKET_NAME")
        .expect("Set S3_EXPRESS_ONE_ZONE_BUCKET_NAME to run integration tests")
}

/// Name of the standard S3 bucket used by integration tests, read from the
/// `S3_BUCKET_NAME` environment variable.
///
/// Panics with a setup hint if the variable is unset.
pub fn get_standard_bucket() -> String {
    std::env::var("S3_BUCKET_NAME").expect("Set S3_BUCKET_NAME to run integration tests")
}

pub fn get_test_bucket_forbidden() -> String {
Expand Down
Loading

0 comments on commit 62dca0f

Please sign in to comment.