Skip to content

Commit

Permalink
do pre-check in validate fn
Browse files Browse the repository at this point in the history
  • Loading branch information
yjhmelody committed Jun 24, 2024
1 parent 4cb0bd4 commit 7cc9fb2
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 47 deletions.
43 changes: 42 additions & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use anyhow::Context;
use anyhow::{bail, Context};
use regex::{Captures, Regex};
use std::env;
use std::fs;
use std::num::NonZeroU32;
use std::path;
use std::time::Duration;

use garde::Validate;
use governor::RateLimiter;
use serde::Deserialize;

use crate::extensions::rate_limit::build_quota;
use crate::extensions::ExtensionsConfig;
pub use rpc::*;

Expand Down Expand Up @@ -208,6 +212,43 @@ pub async fn validate(config: &Config) -> Result<(), anyhow::Error> {

// validate use garde::Validate
config.validate(&())?;

if let Some(rate_limit) = config.extensions.rate_limit.as_ref() {
if let Some(ref rule) = rate_limit.ip {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let quota = build_quota(burst, period);
let limiter = RateLimiter::direct(quota);

for method in &config.rpcs.methods {
if let Some(n) = NonZeroU32::new(method.rate_limit_weight) {
if limiter.check_n(n).is_err() {
bail!("`{}` weight config too big for ip rate limit: {}", method.method, n);
}
}
}
}

if let Some(ref rule) = rate_limit.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let quota = build_quota(burst, period);
let limiter = RateLimiter::direct(quota);

for method in &config.rpcs.methods {
if let Some(n) = NonZeroU32::new(method.rate_limit_weight) {
if limiter.check_n(n).is_err() {
bail!(
"`{}` weight config too big for connection rate limit: {}",
method.method,
n
);
}
}
}
}
}

// since endpoints connection test is async
// we can't intergrate it into garde::Validate
// and it's not a static validation like format, length, .etc
Expand Down
39 changes: 0 additions & 39 deletions src/extensions/rate_limit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use anyhow::bail;
use governor::{DefaultKeyedRateLimiter, Jitter, Quota, RateLimiter};
use serde::Deserialize;
use std::num::NonZeroU32;
Expand Down Expand Up @@ -92,44 +91,6 @@ impl RateLimitBuilder {
}
}

pub fn pre_check_connection(&self, method_weights: &MethodWeights) -> anyhow::Result<()> {
if let Some(ref rule) = self.config.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let quota = build_quota(burst, period);
let limiter = RateLimiter::direct(quota);

for (method, weight) in method_weights.0.as_ref() {
if let Some(n) = NonZeroU32::new(*weight) {
if limiter.check_n(n).is_err() {
bail!("`{method}` weight config too big for connection rate limit: {}", n);
}
}
}
}

Ok(())
}

pub fn pre_check_ip(&self, method_weights: &MethodWeights) -> anyhow::Result<()> {
if let Some(ref rule) = self.config.ip {
let burst = NonZeroU32::new(rule.burst).unwrap();
let period = Duration::from_secs(rule.period_secs);
let quota = build_quota(burst, period);
let limiter = RateLimiter::direct(quota);

for (method, weight) in method_weights.0.as_ref() {
if let Some(n) = NonZeroU32::new(*weight) {
if limiter.check_n(n).is_err() {
bail!("`{method}` weight config too big for ip rate limit: {}", n);
}
}
}
}

Ok(())
}

pub fn connection_limit(&self, method_weights: MethodWeights) -> Option<ConnectionRateLimitLayer> {
if let Some(ref rule) = self.config.connection {
let burst = NonZeroU32::new(rule.burst).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/rate_limit/weight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::BTreeMap;
use std::sync::Arc;

#[derive(Clone, Debug, Default)]
pub struct MethodWeights(pub(crate) Arc<BTreeMap<String, u32>>);
pub struct MethodWeights(Arc<BTreeMap<String, u32>>);

impl MethodWeights {
pub fn get(&self, method: &str) -> u32 {
Expand Down
6 changes: 0 additions & 6 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,6 @@ pub async fn build(config: Config) -> anyhow::Result<SubwayServerHandle> {

let rpc_method_weights = MethodWeights::from_config(&config.rpcs.methods);

// pre-check stage
if let Some(r) = &rate_limit_builder {
r.pre_check_ip(&rpc_method_weights)?;
r.pre_check_connection(&rpc_method_weights)?;
}

let request_timeout_seconds = server_builder.config.request_timeout_seconds;

let metrics = get_rpc_metrics(&extensions_registry).await;
Expand Down

0 comments on commit 7cc9fb2

Please sign in to comment.