Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Add rate-limiting for account recovery and registration (#3093)
Browse files Browse the repository at this point in the history
* Add rate-limiting for account recovery and registration

* Rename login ratelimiter `per_address` to `per_ip` for consistency

Co-authored-by: Quentin Gliech <quenting@element.io>
  • Loading branch information
reivilibre and sandhose authored Aug 7, 2024
1 parent 244f8f5 commit 5d4a4a6
Show file tree
Hide file tree
Showing 10 changed files with 320 additions and 35 deletions.
94 changes: 86 additions & 8 deletions crates/config/src/sections/rate_limiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,28 @@ use crate::ConfigurationSection;
/// Configuration related to sending emails
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimitingConfig {
/// Account Recovery-specific rate limits
#[serde(default)]
pub account_recovery: AccountRecoveryRateLimitingConfig,
/// Login-specific rate limits
#[serde(default)]
pub login: LoginRateLimitingConfig,
/// Controls how many registrations attempts are permitted
/// based on source address.
#[serde(default = "default_registration")]
pub registration: RateLimiterConfiguration,
}

#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct LoginRateLimitingConfig {
/// Controls how many login attempts are permitted
/// based on source address.
/// based on source IP address.
/// This can protect against brute force login attempts.
///
/// Note: this limit also applies to password checks when a user attempts to
/// change their own password.
#[serde(default = "default_login_per_address")]
pub per_address: RateLimiterConfiguration,
#[serde(default = "default_login_per_ip")]
pub per_ip: RateLimiterConfiguration,
/// Controls how many login attempts are permitted
/// based on the account that is being attempted to be logged into.
/// This can protect against a distributed brute force attack
Expand All @@ -50,6 +57,24 @@ pub struct LoginRateLimitingConfig {
pub per_account: RateLimiterConfiguration,
}

#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AccountRecoveryRateLimitingConfig {
/// Controls how many account recovery attempts are permitted
/// based on source IP address.
/// This can protect against causing e-mail spam to many targets.
///
/// Note: this limit also applies to re-sends.
#[serde(default = "default_account_recovery_per_ip")]
pub per_ip: RateLimiterConfiguration,
/// Controls how many account recovery attempts are permitted
/// based on the e-mail address entered into the recovery form.
/// This can protect against causing e-mail spam to one target.
///
/// Note: this limit also applies to re-sends.
#[serde(default = "default_account_recovery_per_address")]
pub per_address: RateLimiterConfiguration,
}

#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimiterConfiguration {
/// A one-off burst of actions that the user can perform
Expand All @@ -66,6 +91,13 @@ impl ConfigurationSection for RateLimitingConfig {
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
let metadata = figment.find_metadata(Self::PATH.unwrap());

let error_on_field = |mut error: figment::error::Error, field: &'static str| {
error.metadata = metadata.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
error
};

let error_on_nested_field =
|mut error: figment::error::Error, container: &'static str, field: &'static str| {
error.metadata = metadata.cloned();
Expand All @@ -92,8 +124,23 @@ impl ConfigurationSection for RateLimitingConfig {
None
};

if let Some(error) = error_on_limiter(&self.login.per_address) {
return Err(error_on_nested_field(error, "login", "per_address"));
if let Some(error) = error_on_limiter(&self.account_recovery.per_ip) {
return Err(error_on_nested_field(error, "account_recovery", "per_ip"));
}
if let Some(error) = error_on_limiter(&self.account_recovery.per_address) {
return Err(error_on_nested_field(
error,
"account_recovery",
"per_address",
));
}

if let Some(error) = error_on_limiter(&self.registration) {
return Err(error_on_field(error, "registration"));
}

if let Some(error) = error_on_limiter(&self.login.per_ip) {
return Err(error_on_nested_field(error, "login", "per_ip"));
}
if let Some(error) = error_on_limiter(&self.login.per_account) {
return Err(error_on_nested_field(error, "login", "per_account"));
Expand All @@ -119,7 +166,7 @@ impl RateLimiterConfiguration {
}
}

fn default_login_per_address() -> RateLimiterConfiguration {
fn default_login_per_ip() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 60.0,
Expand All @@ -133,20 +180,51 @@ fn default_login_per_account() -> RateLimiterConfiguration {
}
}

#[allow(clippy::derivable_impls)] // when we add some top-level ratelimiters this will not be derivable anymore
fn default_registration() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 3600.0,
}
}

fn default_account_recovery_per_ip() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 3600.0,
}
}

fn default_account_recovery_per_address() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 1.0 / 3600.0,
}
}

impl Default for RateLimitingConfig {
fn default() -> Self {
RateLimitingConfig {
login: LoginRateLimitingConfig::default(),
registration: default_registration(),
account_recovery: AccountRecoveryRateLimitingConfig::default(),
}
}
}

impl Default for LoginRateLimitingConfig {
fn default() -> Self {
LoginRateLimitingConfig {
per_address: default_login_per_address(),
per_ip: default_login_per_ip(),
per_account: default_login_per_account(),
}
}
}

impl Default for AccountRecoveryRateLimitingConfig {
fn default() -> Self {
AccountRecoveryRateLimitingConfig {
per_ip: default_account_recovery_per_ip(),
per_address: default_account_recovery_per_address(),
}
}
}
74 changes: 73 additions & 1 deletion crates/handlers/src/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ use mas_config::RateLimitingConfig;
use mas_data_model::User;
use ulid::Ulid;

#[derive(Debug, Clone, thiserror::Error)]
pub enum AccountRecoveryLimitedError {
#[error("Too many account recovery requests for requester {0}")]
Requester(RequesterFingerprint),

#[error("Too many account recovery requests for e-mail {0}")]
Email(String),
}

#[derive(Debug, Clone, Copy, thiserror::Error)]
pub enum PasswordCheckLimitedError {
#[error("Too many password checks for requester {0}")]
Expand All @@ -28,6 +37,12 @@ pub enum PasswordCheckLimitedError {
User(Ulid),
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum RegistrationLimitedError {
#[error("Too many account registration requests for requester {0}")]
Requester(RequesterFingerprint),
}

/// Key used to rate limit requests per requester
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
Expand Down Expand Up @@ -66,15 +81,25 @@ type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;

#[derive(Debug)]
struct LimiterInner {
account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
account_recovery_per_email: KeyedRateLimiter<String>,
password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
password_check_for_user: KeyedRateLimiter<Ulid>,
registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
}

impl LimiterInner {
fn new(config: &RateLimitingConfig) -> Option<Self> {
Some(Self {
password_check_for_requester: RateLimiter::keyed(config.login.per_address.to_quota()?),
account_recovery_per_requester: RateLimiter::keyed(
config.account_recovery.per_ip.to_quota()?,
),
account_recovery_per_email: RateLimiter::keyed(
config.account_recovery.per_address.to_quota()?,
),
password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
})
}
}
Expand Down Expand Up @@ -105,14 +130,44 @@ impl Limiter {

loop {
// Call the retain_recent method on each rate limiter
this.inner.account_recovery_per_email.retain_recent();
this.inner.account_recovery_per_requester.retain_recent();
this.inner.password_check_for_requester.retain_recent();
this.inner.password_check_for_user.retain_recent();
this.inner.registration_per_requester.retain_recent();

interval.tick().await;
}
});
}

/// Check if an account recovery can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited.
pub fn check_account_recovery(
&self,
requester: RequesterFingerprint,
email_address: &str,
) -> Result<(), AccountRecoveryLimitedError> {
self.inner
.account_recovery_per_requester
.check_key(&requester)
.map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;

// Convert to lowercase to prevent bypassing the limit by enumerating different
// case variations.
// A case-folding transformation may be more proper.
let canonical_email = email_address.to_lowercase();
self.inner
.account_recovery_per_email
.check_key(&canonical_email)
.map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;

Ok(())
}

/// Check if a password check can be performed
///
/// # Errors
Expand All @@ -135,6 +190,23 @@ impl Limiter {

Ok(())
}

/// Check if an account registration can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited.
pub fn check_registration(
&self,
requester: RequesterFingerprint,
) -> Result<(), RegistrationLimitedError> {
self.inner
.registration_per_requester
.check_key(&requester)
.map_err(|_| RegistrationLimitedError::Requester(requester))?;

Ok(())
}
}

#[cfg(test)]
Expand Down
19 changes: 16 additions & 3 deletions crates/handlers/src/views/recovery/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use axum::{
response::{Html, IntoResponse, Response},
Form,
};
use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
Expand All @@ -31,7 +32,7 @@ use mas_storage::{
use mas_templates::{EmptyContext, RecoveryProgressContext, TemplateContext, Templates};
use ulid::Ulid;

use crate::PreferredLanguage;
use crate::{Limiter, PreferredLanguage, RequesterFingerprint};

pub(crate) async fn get(
mut rng: BoxRng,
Expand Down Expand Up @@ -74,7 +75,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, Html(rendered)).into_response());
}

let context = RecoveryProgressContext::new(recovery_session)
let context = RecoveryProgressContext::new(recovery_session, false)
.with_csrf(csrf_token.form_value())
.with_language(locale);

Expand All @@ -92,6 +93,7 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Expand Down Expand Up @@ -130,14 +132,25 @@ pub(crate) async fn post(
// Verify the CSRF token
let () = cookie_jar.verify_form(&clock, form)?;

// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_account_recovery(requester, &recovery_session.email) {
tracing::warn!(error = &e as &dyn std::error::Error);
let context = RecoveryProgressContext::new(recovery_session, true)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_progress(&context)?;

return Ok((StatusCode::TOO_MANY_REQUESTS, (cookie_jar, Html(rendered))).into_response());
}

// Schedule a new batch of emails
repo.job()
.schedule_job(SendAccountRecoveryEmailsJob::new(&recovery_session))
.await?;

repo.save().await?;

let context = RecoveryProgressContext::new(recovery_session)
let context = RecoveryProgressContext::new(recovery_session, false)
.with_csrf(csrf_token.form_value())
.with_language(locale);

Expand Down
13 changes: 11 additions & 2 deletions crates/handlers/src/views/recovery/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{
EmptyContext, FieldError, FormState, RecoveryStartContext, RecoveryStartFormField,
EmptyContext, FieldError, FormError, FormState, RecoveryStartContext, RecoveryStartFormField,
TemplateContext, Templates,
};
use serde::{Deserialize, Serialize};

use crate::{BoundActivityTracker, PreferredLanguage};
use crate::{BoundActivityTracker, Limiter, PreferredLanguage, RequesterFingerprint};

#[derive(Deserialize, Serialize)]
pub(crate) struct StartRecoveryForm {
Expand Down Expand Up @@ -90,6 +90,7 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Form(form): Form<ProtectedForm<StartRecoveryForm>>,
Expand Down Expand Up @@ -120,6 +121,14 @@ pub(crate) async fn post(
form_state.with_error_on_field(RecoveryStartFormField::Email, FieldError::Invalid);
}

if form_state.is_valid() {
// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_account_recovery(requester, &form.email) {
tracing::warn!(error = &e as &dyn std::error::Error);
form_state.add_error_on_form(FormError::RateLimitExceeded);
}
}

if !form_state.is_valid() {
repo.save().await?;
let context = RecoveryStartContext::new()
Expand Down
Loading

0 comments on commit 5d4a4a6

Please sign in to comment.