diff --git a/mullvad-api/src/availability.rs b/mullvad-api/src/availability.rs
index ba33836b16ce..339aca8bcab1 100644
--- a/mullvad-api/src/availability.rs
+++ b/mullvad-api/src/availability.rs
@@ -1,12 +1,10 @@
 use std::{
     future::Future,
-    sync::{Arc, Mutex},
+    sync::{Arc, Mutex, MutexGuard},
     time::Duration,
 };
 use tokio::sync::broadcast;
 
-const CHANNEL_CAPACITY: usize = 100;
-
 /// Pause background requests if [ApiAvailabilityHandle::reset_inactivity_timer] hasn't been
 /// called for this long.
 const INACTIVITY_TIME: Duration = Duration::from_secs(3 * 24 * 60 * 60);
@@ -26,182 +24,108 @@ pub struct State {
     inactive: bool,
 }
 
+#[derive(Clone, Debug)]
+pub struct ApiAvailability(Arc<Mutex<ApiAvailabilityState>>);
+
+#[derive(Debug)]
+struct ApiAvailabilityState {
+    tx: broadcast::Sender<State>,
+    state: State,
+    inactivity_timer: Option<tokio::task::JoinHandle<()>>,
+}
+
 impl State {
-    pub fn is_suspended(&self) -> bool {
+    pub const fn is_suspended(&self) -> bool {
         self.suspended
     }
 
-    pub fn is_background_paused(&self) -> bool {
+    pub const fn is_background_paused(&self) -> bool {
         self.offline || self.pause_background || self.suspended || self.inactive
     }
 
-    pub fn is_offline(&self) -> bool {
+    pub const fn is_offline(&self) -> bool {
         self.offline
     }
 }
 
-pub struct ApiAvailability {
-    state: Arc<Mutex<State>>,
-    tx: broadcast::Sender<State>,
-
-    inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
-}
-
 impl ApiAvailability {
-    pub fn new(initial_state: State) -> Self {
-        let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY);
-        let state = Arc::new(Mutex::new(initial_state));
+    const CHANNEL_CAPACITY: usize = 100;
 
-        let availability = ApiAvailability {
-            state,
+    pub fn new(initial_state: State) -> Self {
+        let (tx, _rx) = broadcast::channel(ApiAvailability::CHANNEL_CAPACITY);
+        let inner = ApiAvailabilityState {
+            state: initial_state,
+            inactivity_timer: None,
             tx,
-            inactivity_timer: Arc::new(Mutex::new(None)),
         };
-        availability.handle().reset_inactivity_timer();
-        availability
+        let handle = ApiAvailability(Arc::new(Mutex::new(inner)));
+        // Start an inactivity timer
+        handle.reset_inactivity_timer();
+        handle
     }
 
-    pub fn get_state(&self) -> State {
-        *self.state.lock().unwrap()
+    fn acquire(&self) -> MutexGuard<'_, ApiAvailabilityState> {
+        self.0.lock().unwrap()
     }
 
-    pub fn handle(&self) -> ApiAvailabilityHandle {
-        ApiAvailabilityHandle {
-            state: self.state.clone(),
-            tx: self.tx.clone(),
-            inactivity_timer: self.inactivity_timer.clone(),
-        }
-    }
-}
-
-impl Drop for ApiAvailability {
-    fn drop(&mut self) {
-        if let Some(timer) = self.inactivity_timer.lock().unwrap().take() {
-            timer.abort();
-        }
-    }
-}
-
-#[derive(Clone, Debug)]
-pub struct ApiAvailabilityHandle {
-    state: Arc<Mutex<State>>,
-    tx: broadcast::Sender<State>,
-    inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
-}
-
-impl ApiAvailabilityHandle {
     /// Reset task that automatically pauses API requests due inactivity,
     /// starting it if it's not currently running.
     pub fn reset_inactivity_timer(&self) {
-        log::trace!("Restarting API inactivity check");
-
-        let self_ = self.clone();
-
-        let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
-        if let Some(timer) = inactivity_timer.take() {
-            timer.abort();
-        }
-
-        self.set_active();
-
-        *inactivity_timer = Some(tokio::spawn(async move {
+        let mut inner = self.0.lock().unwrap();
+        log::debug!("Restarting API inactivity check");
+        inner.stop_inactivity_timer();
+        let availability_handle = self.clone();
+        inner.inactivity_timer = Some(tokio::spawn(async move {
             talpid_time::sleep(INACTIVITY_TIME).await;
-            self_.set_inactive();
+            availability_handle.set_inactive();
         }));
+        inner.set_active();
     }
 
     /// Stops timer that pauses API requests due to inactivity.
     pub fn stop_inactivity_timer(&self) {
-        log::trace!("Stopping API inactivity check");
-
-        let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
-        if let Some(timer) = inactivity_timer.take() {
-            timer.abort();
-        }
-        self.set_active();
-    }
-
-    fn inactivity_timer_running(&self) -> bool {
-        self.inactivity_timer.lock().unwrap().is_some()
-    }
-
-    pub fn suspend(&self) {
-        let mut state = self.state.lock().unwrap();
-        if !state.suspended {
-            log::debug!("Suspending API requests");
-
-            state.suspended = true;
-            let _ = self.tx.send(*state);
-        }
-    }
-
-    pub fn unsuspend(&self) {
-        let mut state = self.state.lock().unwrap();
-        if state.suspended {
-            log::debug!("Unsuspending API requests");
-
-            state.suspended = false;
-            let _ = self.tx.send(*state);
-        }
+        self.acquire().stop_inactivity_timer();
     }
 
     pub fn pause_background(&self) {
-        let mut state = self.state.lock().unwrap();
-        if !state.pause_background {
-            log::debug!("Pausing background API requests");
-
-            state.pause_background = true;
-            let _ = self.tx.send(*state);
-        }
+        self.acquire().pause_background();
     }
 
     pub fn resume_background(&self) {
-        if self.inactivity_timer_running() {
+        let should_reset = {
+            let mut inner = self.acquire();
+            inner.resume_background();
+            inner.inactivity_timer_running()
+        };
+        // Note: It is important that we do not hold on to the Mutex when calling
+        // `reset_inactivity_timer()`, since it locks the same Mutex again and would deadlock.
+        if should_reset {
             self.reset_inactivity_timer();
         }
-
-        let mut state = self.state.lock().unwrap();
-        if state.pause_background {
-            log::debug!("Resuming background API requests");
-            state.pause_background = false;
-            let _ = self.tx.send(*state);
-        }
     }
 
-    fn set_inactive(&self) {
-        let mut state = self.state.lock().unwrap();
-        if !state.inactive {
-            log::debug!("Pausing background API requests due to inactivity");
-            state.inactive = true;
-            let _ = self.tx.send(*state);
-        }
+    pub fn suspend(&self) {
+        self.acquire().suspend()
     }
 
-    fn set_active(&self) {
-        let mut state = self.state.lock().unwrap();
-        if state.inactive {
-            log::debug!("Resuming background API requests due to activity");
-            state.inactive = false;
-            let _ = self.tx.send(*state);
-        }
+    pub fn unsuspend(&self) {
+        self.acquire().unsuspend();
     }
 
     pub fn set_offline(&self, offline: bool) {
-        let mut state = self.state.lock().unwrap();
-        if state.offline != offline {
-            if offline {
-                log::debug!("Pausing API requests due to being offline");
-            } else {
-                log::debug!("Resuming API requests due to coming online");
-            }
+        self.acquire().set_offline(offline);
+    }
 
-            state.offline = offline;
-            let _ = self.tx.send(*state);
-        }
+    fn set_inactive(&self) {
+        self.acquire().set_inactive();
+    }
+
+    /// Check if the host is offline
+    pub fn is_offline(&self) -> bool {
+        self.get_state().is_offline()
     }
 
-    pub fn get_state(&self) -> State {
-        *self.state.lock().unwrap()
+    fn get_state(&self) -> State {
+        self.acquire().state
     }
 
     pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> {
@@ -236,12 +160,12 @@ impl ApiAvailabilityHandle {
         &self,
         state_ready: impl Fn(State) -> bool,
     ) -> impl Future<Output = Result<(), Error>> {
-        let mut rx = self.tx.subscribe();
-        let state = self.state.clone();
+        let mut rx = { self.acquire().tx.subscribe() };
+        let handle = self.clone();
 
         async move {
-            let current_state = { *state.lock().unwrap() };
-            if state_ready(current_state) {
+            let state = handle.get_state();
+            if state_ready(state) {
                 return Ok(());
             }
 
@@ -254,3 +178,79 @@ impl ApiAvailabilityHandle {
         }
     }
 }
+
+impl ApiAvailabilityState {
+    fn suspend(&mut self) {
+        if !self.state.suspended {
+            log::trace!("Suspending API requests");
+            self.state.suspended = true;
+            let _ = self.tx.send(self.state);
+        }
+    }
+
+    fn unsuspend(&mut self) {
+        if self.state.suspended {
+            log::trace!("Unsuspending API requests");
+            self.state.suspended = false;
+            let _ = self.tx.send(self.state);
+        }
+    }
+
+    fn set_inactive(&mut self) {
+        log::trace!("Setting state to inactive");
+        if !self.state.inactive {
+            log::debug!("Pausing background API requests due to inactivity");
+            self.state.inactive = true;
+            let _ = self.tx.send(self.state);
+        }
+    }
+
+    fn set_active(&mut self) {
+        log::trace!("Setting state to active");
+        if self.state.inactive {
+            log::debug!("Resuming background API requests due to activity");
+            self.state.inactive = false;
+            let _ = self.tx.send(self.state).inspect_err(|send_err| {
+                log::debug!("All receivers of state updates have been dropped");
+                log::debug!("{send_err}");
+            });
+        }
+    }
+
+    fn set_offline(&mut self, offline: bool) {
+        if offline {
+            log::debug!("Pausing API requests due to being offline");
+        } else {
+            log::debug!("Resuming API requests due to coming online");
+        }
+        if self.state.offline != offline {
+            self.state.offline = offline;
+            let _ = self.tx.send(self.state);
+        }
+    }
+
+    fn pause_background(&mut self) {
+        if !self.state.pause_background {
+            log::debug!("Pausing background API requests");
+            self.state.pause_background = true;
+            let _ = self.tx.send(self.state);
+        }
+    }
+
+    fn resume_background(&mut self) {
+        if self.state.pause_background {
+            log::debug!("Resuming background API requests");
+            self.state.pause_background = false;
+            let _ = self.tx.send(self.state);
+        }
+    }
+
+    fn stop_inactivity_timer(&mut self) {
+        log::debug!("Stopping API inactivity check");
+        if let Some(timer) = self.inactivity_timer.take() {
+            timer.abort();
+        }
+    }
+
+    const fn inactivity_timer_running(&self) -> bool {
+        self.inactivity_timer.is_some()
+    }
+}
+
+impl Drop for ApiAvailabilityState {
+    fn drop(&mut self) {
+        self.stop_inactivity_timer();
+    }
+}
diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs
index 87b6e3d6567a..8add11d30a3f 100644
--- a/mullvad-api/src/lib.rs
+++ b/mullvad-api/src/lib.rs
@@ -21,7 +21,7 @@ use std::{
 use talpid_types::ErrorExt;
 
 pub mod availability;
-use availability::{ApiAvailability, ApiAvailabilityHandle};
+use availability::ApiAvailability;
 pub mod rest;
 
 mod abortable_stream;
@@ -414,7 +414,7 @@ impl Runtime {
     ) -> rest::RequestServiceHandle {
         rest::RequestService::spawn(
             sni_hostname,
-            self.api_availability.handle(),
+            self.api_availability.clone(),
             self.address_cache.clone(),
             connection_mode_provider,
             #[cfg(target_os = "android")]
@@ -467,8 +467,8 @@ impl Runtime {
         &mut self.handle
     }
 
-    pub fn availability_handle(&self) -> ApiAvailabilityHandle {
-        self.api_availability.handle()
+    pub fn availability_handle(&self) -> ApiAvailability {
+        self.api_availability.clone()
     }
 
     pub fn address_cache(&self) -> &AddressCache {
diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs
index 238d73206ab9..bbcef79903f6 100644
--- a/mullvad-api/src/rest.rs
+++ b/mullvad-api/src/rest.rs
@@ -3,7 +3,7 @@ pub use crate::https_client_with_sni::SocketBypassRequest;
 use crate::{
     access::AccessTokenStore,
     address_cache::AddressCache,
-    availability::ApiAvailabilityHandle,
+    availability::ApiAvailability,
     https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
     proxy::ConnectionModeProvider,
 };
@@ -122,14 +122,14 @@ pub(crate) struct RequestService {
     client: hyper::Client,
     connection_mode_provider: T,
     connection_mode_generation: usize,
-    api_availability: ApiAvailabilityHandle,
+    api_availability: ApiAvailability,
 }
 
 impl RequestService {
     /// Constructs a new request service.
     pub fn spawn(
         sni_hostname: Option,
-        api_availability: ApiAvailabilityHandle,
+        api_availability: ApiAvailability,
         address_cache: AddressCache,
         connection_mode_provider: T,
         #[cfg(target_os = "android")] socket_bypass_tx: Option>,
@@ -218,7 +218,7 @@ impl RequestService {
 
         // Switch API endpoint if the request failed due to a network error
         if let Err(err) = &response {
-            if err.is_network_error() && !api_availability.get_state().is_offline() {
+            if err.is_network_error() && !api_availability.is_offline() {
                 log::error!("{}", err.display_chain_with_msg("HTTP request failed"));
                 if let Some(tx) = tx {
                     let _ = tx.unbounded_send(RequestCommand::NextApiConfig(
@@ -339,7 +339,7 @@ impl Request {
     async fn into_future(
         self,
         hyper_client: hyper::Client,
-        api_availability: ApiAvailabilityHandle,
+        api_availability: ApiAvailability,
     ) -> Result {
         let timeout = self.timeout;
         let inner_fut = self.into_future_without_timeout(hyper_client, api_availability);
@@ -351,7 +351,7 @@ impl Request {
     async fn into_future_without_timeout(
         mut self,
         hyper_client: hyper::Client,
-        api_availability: ApiAvailabilityHandle,
+        api_availability: ApiAvailability,
     ) -> Result {
         let _ = api_availability.wait_for_unsuspend().await;
 
@@ -605,14 +605,14 @@ async fn deserialize_body_inner(
 pub struct MullvadRestHandle {
     pub(crate) service: RequestServiceHandle,
     pub factory: RequestFactory,
-    pub availability: ApiAvailabilityHandle,
+    pub availability: ApiAvailability,
 }
 
 impl MullvadRestHandle {
     pub(crate) fn new(
         service: RequestServiceHandle,
         factory: RequestFactory,
-        availability: ApiAvailabilityHandle,
+        availability: ApiAvailability,
     ) -> Self {
         Self {
             service,
diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs
index acfdbb766400..ac54382a57eb 100644
--- a/mullvad-daemon/src/api.rs
+++ b/mullvad-daemon/src/api.rs
@@ -11,7 +11,7 @@ use futures::{
     StreamExt,
 };
 use mullvad_api::{
-    availability::ApiAvailabilityHandle,
+    availability::ApiAvailability,
     proxy::{ApiConnectionMode, ConnectionModeProvider, ProxyConfig},
     AddressCache,
 };
@@ -578,9 +578,9 @@ pub fn allowed_clients(connection_mode: &ApiConnectionMode) -> AllowedClients {
     }
 }
 
-/// Forwards the received values from `offline_state_rx` to the [`ApiAvailabilityHandle`].
+/// Forwards the received values from `offline_state_rx` to the [`ApiAvailability`].
 pub(crate) fn forward_offline_state(
-    api_availability: ApiAvailabilityHandle,
+    api_availability: ApiAvailability,
     mut offline_state_rx: mpsc::UnboundedReceiver,
 ) {
     tokio::spawn(async move {
diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs
index c4c949cebae9..093dd14f6746 100644
--- a/mullvad-daemon/src/device/service.rs
+++ b/mullvad-daemon/src/device/service.rs
@@ -13,7 +13,7 @@ use talpid_types::net::wireguard::PrivateKey;
 
 use super::{Error, PrivateAccountAndDevice, PrivateDevice};
 use mullvad_api::{
-    availability::ApiAvailabilityHandle,
+    availability::ApiAvailability,
     rest::{self, MullvadRestHandle},
     AccountsProxy, DevicesProxy,
 };
@@ -28,12 +28,12 @@ const RETRY_BACKOFF_STRATEGY: Jittered = Jittered::jitter(
 
 #[derive(Clone)]
 pub struct DeviceService {
-    api_availability: ApiAvailabilityHandle,
+    api_availability: ApiAvailability,
     proxy: DevicesProxy,
 }
 
 impl DeviceService {
-    pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailabilityHandle) -> Self {
+    pub fn new(handle: rest::MullvadRestHandle, api_availability: ApiAvailability) -> Self {
         Self {
             proxy: DevicesProxy::new(handle),
             api_availability,
@@ -255,7 +255,7 @@ impl DeviceService {
 
 #[derive(Clone)]
 pub struct AccountService {
-    api_availability: ApiAvailabilityHandle,
+    api_availability: ApiAvailability,
     initial_check_abort_handle: AbortHandle,
     proxy: AccountsProxy,
 }
@@ -368,7 +368,7 @@ impl AccountService {
 pub fn spawn_account_service(
     api_handle: MullvadRestHandle,
     token: Option,
-    api_availability: ApiAvailabilityHandle,
+    api_availability: ApiAvailability,
 ) -> AccountService {
     let accounts_proxy = AccountsProxy::new(api_handle);
     api_availability.pause_background();
@@ -403,7 +403,7 @@ pub fn spawn_account_service(
 
 fn handle_account_data_result(
     result: &Result,
-    api_availability: &ApiAvailabilityHandle,
+    api_availability: &ApiAvailability,
 ) -> bool {
     match result {
         Ok(_data) if _data.expiry >= chrono::Utc::now() => {
@@ -425,9 +425,9 @@ fn handle_account_data_result(
     }
 }
 
-fn should_retry(result: &Result, api_handle: &ApiAvailabilityHandle) -> bool {
+fn should_retry(result: &Result, api_handle: &ApiAvailability) -> bool {
     match result {
-        Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(),
+        Err(error) if error.is_network_error() => !api_handle.is_offline(),
         _ => false,
     }
 }
diff --git a/mullvad-daemon/src/relay_list/mod.rs b/mullvad-daemon/src/relay_list/mod.rs
index 2b4be3db5432..99fa60df57c8 100644
--- a/mullvad-daemon/src/relay_list/mod.rs
+++ b/mullvad-daemon/src/relay_list/mod.rs
@@ -11,7 +11,7 @@ use std::{
 };
 use tokio::fs::File;
 
-use mullvad_api::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, RelayListProxy};
+use mullvad_api::{availability::ApiAvailability, rest::MullvadRestHandle, RelayListProxy};
 use mullvad_relay_selector::RelaySelector;
 use mullvad_types::relay_list::RelayList;
 use talpid_future::retry::{retry_future, ExponentialBackoff, Jittered};
@@ -68,7 +68,7 @@ pub struct RelayListUpdater {
     relay_selector: RelaySelector,
     on_update: Box,
     last_check: SystemTime,
-    api_availability: ApiAvailabilityHandle,
+    api_availability: ApiAvailability,
 }
 
 impl RelayListUpdater {
@@ -163,7 +163,7 @@ impl RelayListUpdater {
     }
 
     fn download_relay_list(
-        api_handle: ApiAvailabilityHandle,
+        api_handle: ApiAvailability,
         proxy: RelayListProxy,
         tag: Option,
     ) -> impl Future, mullvad_api::Error>> + 'static {
diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs
index 4aae30f574d3..ed50eb6ff570 100644
--- a/mullvad-daemon/src/version_check.rs
+++ b/mullvad-daemon/src/version_check.rs
@@ -4,7 +4,7 @@ use futures::{
     future::{BoxFuture, FusedFuture},
     FutureExt, SinkExt, StreamExt, TryFutureExt,
 };
-use mullvad_api::{availability::ApiAvailabilityHandle, rest::MullvadRestHandle, AppVersionProxy};
+use mullvad_api::{availability::ApiAvailability, rest::MullvadRestHandle, AppVersionProxy};
 use mullvad_types::version::{AppVersionInfo, ParsedAppVersion};
 use serde::{Deserialize, Serialize};
 use std::{
@@ -149,7 +149,7 @@ impl VersionUpdaterHandle {
 impl VersionUpdater {
     pub async fn spawn(
         mut api_handle: MullvadRestHandle,
-        availability_handle: ApiAvailabilityHandle,
+        availability_handle: ApiAvailability,
         cache_dir: PathBuf,
         update_sender: DaemonEventSender,
         show_beta_releases: bool,
@@ -413,7 +413,7 @@ impl UpdateContext {
 
 #[derive(Clone)]
 struct ApiContext {
-    api_handle: ApiAvailabilityHandle,
+    api_handle: ApiAvailability,
     version_proxy: AppVersionProxy,
     platform_version: String,
 }
@@ -435,7 +435,7 @@ fn do_version_check(
     // retry immediately on network errors (unless we're offline)
     let should_retry_immediate = move |result: &Result<_, Error>| {
         if let Err(Error::Download(error)) = result {
-            error.is_network_error() && !api.api_handle.get_state().is_offline()
+            error.is_network_error() && !api.api_handle.is_offline()
         } else {
             false
         }
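
Illustrative usage sketch (not part of the patch): after this change there is no separate ApiAvailabilityHandle. ApiAvailability itself derives Clone, and every clone shares the same Arc<Mutex<ApiAvailabilityState>>, so callers clone it where they previously called handle(). The snippet below only exercises public methods that appear in this diff and assumes that State implements Default, which the diff does not show.

    // Hypothetical example, not daemon code. Assumes `State: Default`.
    use mullvad_api::availability::{ApiAvailability, State};

    async fn example_usage() {
        let availability = ApiAvailability::new(State::default());

        // A clone is now the "handle": it shares the inner state with the original.
        let background = availability.clone();
        background.pause_background();

        // Offline status is queried directly instead of via `get_state().is_offline()`.
        if !availability.is_offline() {
            availability.suspend();
            // `wait_for_unsuspend()` resolves once any handle unsuspends requests.
            let wait = availability.wait_for_unsuspend();
            availability.unsuspend();
            let _ = wait.await;
        }
    }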
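The "// Note:" in resume_background is the one subtle point of the new locking scheme: reset_inactivity_timer() locks the same std::sync::Mutex that acquire() returns a guard for, and that Mutex is not re-entrant, so the guard must be dropped before the call. A minimal stand-alone sketch of the pattern (all names here are made up; only the lock scoping mirrors the patch):

    use std::sync::{Arc, Mutex};

    #[derive(Clone)]
    struct Handle(Arc<Mutex<Inner>>);

    struct Inner {
        timer_running: bool,
    }

    impl Handle {
        fn reset_timer(&self) {
            // Takes the inner lock again.
            self.0.lock().unwrap().timer_running = true;
        }

        fn resume(&self) {
            // Compute what is needed while holding the lock, then let the guard drop...
            let should_reset = {
                let inner = self.0.lock().unwrap();
                inner.timer_running
            }; // <- guard dropped at the end of this block
            // ...so that this call can lock the mutex without deadlocking.
            if should_reset {
                self.reset_timer();
            }
        }
    }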