From 7cc9fb24bd3f4decaf8708956d338d748389d569 Mon Sep 17 00:00:00 2001 From: yjhmelody Date: Mon, 24 Jun 2024 18:13:32 +0800 Subject: [PATCH] do pre-check in `validate` fn --- src/config/mod.rs | 43 ++++++++++++++++++++++++++++- src/extensions/rate_limit/mod.rs | 39 -------------------------- src/extensions/rate_limit/weight.rs | 2 +- src/server.rs | 6 ---- 4 files changed, 43 insertions(+), 47 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 5424157..2d62cc2 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,12 +1,16 @@ -use anyhow::Context; +use anyhow::{bail, Context}; use regex::{Captures, Regex}; use std::env; use std::fs; +use std::num::NonZeroU32; use std::path; +use std::time::Duration; use garde::Validate; +use governor::RateLimiter; use serde::Deserialize; +use crate::extensions::rate_limit::build_quota; use crate::extensions::ExtensionsConfig; pub use rpc::*; @@ -208,6 +212,43 @@ pub async fn validate(config: &Config) -> Result<(), anyhow::Error> { // validate use garde::Validate config.validate(&())?; + + if let Some(rate_limit) = config.extensions.rate_limit.as_ref() { + if let Some(ref rule) = rate_limit.ip { + let burst = NonZeroU32::new(rule.burst).unwrap(); + let period = Duration::from_secs(rule.period_secs); + let quota = build_quota(burst, period); + let limiter = RateLimiter::direct(quota); + + for method in &config.rpcs.methods { + if let Some(n) = NonZeroU32::new(method.rate_limit_weight) { + if limiter.check_n(n).is_err() { + bail!("`{}` weight config too big for ip rate limit: {}", method.method, n); + } + } + } + } + + if let Some(ref rule) = rate_limit.connection { + let burst = NonZeroU32::new(rule.burst).unwrap(); + let period = Duration::from_secs(rule.period_secs); + let quota = build_quota(burst, period); + let limiter = RateLimiter::direct(quota); + + for method in &config.rpcs.methods { + if let Some(n) = NonZeroU32::new(method.rate_limit_weight) { + if limiter.check_n(n).is_err() { + bail!( + "`{}` weight config too big for connection rate limit: {}", + method.method, + n + ); + } + } + } + } + } + // since endpoints connection test is async // we can't intergrate it into garde::Validate // and it's not a static validation like format, length, .etc diff --git a/src/extensions/rate_limit/mod.rs b/src/extensions/rate_limit/mod.rs index dc76806..61dd2a1 100644 --- a/src/extensions/rate_limit/mod.rs +++ b/src/extensions/rate_limit/mod.rs @@ -1,4 +1,3 @@ -use anyhow::bail; use governor::{DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter}; use serde::Deserialize; use std::num::NonZeroU32; @@ -92,44 +91,6 @@ impl RateLimitBuilder { } } - pub fn pre_check_connection(&self, method_weights: &MethodWeights) -> anyhow::Result<()> { - if let Some(ref rule) = self.config.connection { - let burst = NonZeroU32::new(rule.burst).unwrap(); - let period = Duration::from_secs(rule.period_secs); - let quota = build_quota(burst, period); - let limiter = RateLimiter::direct(quota); - - for (method, weight) in method_weights.0.as_ref() { - if let Some(n) = NonZeroU32::new(*weight) { - if limiter.check_n(n).is_err() { - bail!("`{method}` weight config too big for connection rate limit: {}", n); - } - } - } - } - - Ok(()) - } - - pub fn pre_check_ip(&self, method_weights: &MethodWeights) -> anyhow::Result<()> { - if let Some(ref rule) = self.config.ip { - let burst = NonZeroU32::new(rule.burst).unwrap(); - let period = Duration::from_secs(rule.period_secs); - let quota = build_quota(burst, period); - let limiter = RateLimiter::direct(quota); - - for (method, weight) in method_weights.0.as_ref() { - if let Some(n) = NonZeroU32::new(*weight) { - if limiter.check_n(n).is_err() { - bail!("`{method}` weight config too big for ip rate limit: {}", n); - } - } - } - } - - Ok(()) - } - pub fn connection_limit(&self, method_weights: MethodWeights) -> Option { if let Some(ref rule) = self.config.connection { let burst = NonZeroU32::new(rule.burst).unwrap(); diff --git a/src/extensions/rate_limit/weight.rs b/src/extensions/rate_limit/weight.rs index e5829cc..06fedd0 100644 --- a/src/extensions/rate_limit/weight.rs +++ b/src/extensions/rate_limit/weight.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use std::sync::Arc; #[derive(Clone, Debug, Default)] -pub struct MethodWeights(pub(crate) Arc>); +pub struct MethodWeights(Arc>); impl MethodWeights { pub fn get(&self, method: &str) -> u32 { diff --git a/src/server.rs b/src/server.rs index e434975..6bf2f0c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -51,12 +51,6 @@ pub async fn build(config: Config) -> anyhow::Result { let rpc_method_weights = MethodWeights::from_config(&config.rpcs.methods); - // pre-check stage - if let Some(r) = &rate_limit_builder { - r.pre_check_ip(&rpc_method_weights)?; - r.pre_check_connection(&rpc_method_weights)?; - } - let request_timeout_seconds = server_builder.config.request_timeout_seconds; let metrics = get_rpc_metrics(&extensions_registry).await;