diff --git a/crates/spfs/src/storage/rpc/mod.rs b/crates/spfs/src/storage/rpc/mod.rs index 1ec2448d3..20aead7b7 100644 --- a/crates/spfs/src/storage/rpc/mod.rs +++ b/crates/spfs/src/storage/rpc/mod.rs @@ -9,4 +9,4 @@ mod payload; mod repository; mod tag; -pub use repository::{Config, RpcRepository}; +pub use repository::{Config, Params, RpcRepository}; diff --git a/crates/spfs/src/storage/rpc/repository.rs b/crates/spfs/src/storage/rpc/repository.rs index 0f4c8054f..79cfd84a5 100644 --- a/crates/spfs/src/storage/rpc/repository.rs +++ b/crates/spfs/src/storage/rpc/repository.rs @@ -25,6 +25,21 @@ pub struct Params { /// if true, don't actually attempt to connect until first use #[serde(default)] pub lazy: bool, + + /// The global timeout for all requests made in this client + /// + /// Default is no timeout + pub timeout_ms: Option, + + /// Maximum message size that the client will accept from the server + /// + /// Default is 4 Mb + pub max_decode_message_size_bytes: Option, + + /// Maximum message size that the client will sent to the server + /// + /// Default is no limit + pub max_encode_message_size_bytes: Option, } #[async_trait::async_trait] @@ -84,19 +99,34 @@ impl RpcRepository { /// Create a new rpc repository client for the given configuration pub async fn new(config: Config) -> OpenRepositoryResult { - let endpoint = tonic::transport::Endpoint::from_shared(config.address.to_string()) + let mut endpoint = tonic::transport::Endpoint::from_shared(config.address.to_string()) .map_err(|source| OpenRepositoryError::InvalidTransportAddress { address: config.address.to_string(), source, })?; + if let Some(ms) = config.params.timeout_ms { + endpoint = endpoint.timeout(std::time::Duration::from_millis(ms)); + } let channel = match config.params.lazy { true => endpoint.connect_lazy(), false => endpoint.connect().await?, }; - let repo_client = RepositoryClient::new(channel.clone()); - let tag_client = TagServiceClient::new(channel.clone()); - let db_client = DatabaseServiceClient::new(channel.clone()); - let payload_client = PayloadServiceClient::new(channel); + let mut repo_client = RepositoryClient::new(channel.clone()); + let mut tag_client = TagServiceClient::new(channel.clone()); + let mut db_client = DatabaseServiceClient::new(channel.clone()); + let mut payload_client = PayloadServiceClient::new(channel); + if let Some(max) = config.params.max_decode_message_size_bytes { + repo_client = repo_client.max_decoding_message_size(max); + tag_client = tag_client.max_decoding_message_size(max); + db_client = db_client.max_decoding_message_size(max); + payload_client = payload_client.max_decoding_message_size(max); + } + if let Some(max) = config.params.max_encode_message_size_bytes { + repo_client = repo_client.max_encoding_message_size(max); + tag_client = tag_client.max_encoding_message_size(max); + db_client = db_client.max_encoding_message_size(max); + payload_client = payload_client.max_encoding_message_size(max); + } Ok(Self { address: config.to_address().expect("an internally valid config"), repo_client,