Skip to content

Commit

Permalink
Consolidate two mutexes into one
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Sep 25, 2024
1 parent 7b75e8d commit bb10ecc
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 167 deletions.
274 changes: 137 additions & 137 deletions mullvad-api/src/availability.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::{
future::Future,
sync::{Arc, Mutex},
sync::{Arc, Mutex, MutexGuard},
time::Duration,
};
use tokio::sync::broadcast;

const CHANNEL_CAPACITY: usize = 100;

/// Pause background requests if [ApiAvailabilityHandle::reset_inactivity_timer] hasn't been
/// called for this long.
const INACTIVITY_TIME: Duration = Duration::from_secs(3 * 24 * 60 * 60);
Expand All @@ -26,182 +24,108 @@ pub struct State {
inactive: bool,
}

#[derive(Clone, Debug)]
pub struct ApiAvailability(Arc<Mutex<ApiAvailabilityState>>);

#[derive(Debug)]
struct ApiAvailabilityState {
tx: broadcast::Sender<State>,
state: State,
inactivity_timer: Option<tokio::task::JoinHandle<()>>,
}

impl State {
pub fn is_suspended(&self) -> bool {
pub const fn is_suspended(&self) -> bool {
self.suspended
}

pub fn is_background_paused(&self) -> bool {
pub const fn is_background_paused(&self) -> bool {
self.offline || self.pause_background || self.suspended || self.inactive
}

pub fn is_offline(&self) -> bool {
pub const fn is_offline(&self) -> bool {
self.offline
}
}

pub struct ApiAvailability {
state: Arc<Mutex<State>>,
tx: broadcast::Sender<State>,

inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}

impl ApiAvailability {
pub fn new(initial_state: State) -> Self {
let (tx, _rx) = broadcast::channel(CHANNEL_CAPACITY);
let state = Arc::new(Mutex::new(initial_state));
const CHANNEL_CAPACITY: usize = 100;

let availability = ApiAvailability {
state,
pub fn new(initial_state: State) -> Self {
let (tx, _rx) = broadcast::channel(ApiAvailability::CHANNEL_CAPACITY);
let inner = ApiAvailabilityState {
state: initial_state,
inactivity_timer: None,
tx,
inactivity_timer: Arc::new(Mutex::new(None)),
};
availability.handle().reset_inactivity_timer();
availability
let handle = ApiAvailability(Arc::new(Mutex::new(inner)));
// Start an inactivity timer
handle.reset_inactivity_timer();
handle
}

pub fn get_state(&self) -> State {
*self.state.lock().unwrap()
fn acquire(&self) -> MutexGuard<'_, ApiAvailabilityState> {
self.0.lock().unwrap()
}

pub fn handle(&self) -> ApiAvailabilityHandle {
ApiAvailabilityHandle {
state: self.state.clone(),
tx: self.tx.clone(),
inactivity_timer: self.inactivity_timer.clone(),
}
}
}

impl Drop for ApiAvailability {
fn drop(&mut self) {
if let Some(timer) = self.inactivity_timer.lock().unwrap().take() {
timer.abort();
}
}
}

#[derive(Clone, Debug)]
pub struct ApiAvailabilityHandle {
state: Arc<Mutex<State>>,
tx: broadcast::Sender<State>,
inactivity_timer: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}

impl ApiAvailabilityHandle {
/// Reset task that automatically pauses API requests due inactivity,
/// starting it if it's not currently running.
pub fn reset_inactivity_timer(&self) {
log::trace!("Restarting API inactivity check");

let self_ = self.clone();

let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
if let Some(timer) = inactivity_timer.take() {
timer.abort();
}

self.set_active();

*inactivity_timer = Some(tokio::spawn(async move {
let mut inner = self.0.lock().unwrap();
log::debug!("Restarting API inactivity check");
inner.stop_inactivity_timer();
let availability_handle = self.clone();
inner.inactivity_timer = Some(tokio::spawn(async move {
talpid_time::sleep(INACTIVITY_TIME).await;
self_.set_inactive();
availability_handle.set_inactive();
}));
inner.set_active();
}

/// Stops timer that pauses API requests due to inactivity.
pub fn stop_inactivity_timer(&self) {
log::trace!("Stopping API inactivity check");

let mut inactivity_timer = self.inactivity_timer.lock().unwrap();
if let Some(timer) = inactivity_timer.take() {
timer.abort();
}
self.set_active();
}

fn inactivity_timer_running(&self) -> bool {
self.inactivity_timer.lock().unwrap().is_some()
}

pub fn suspend(&self) {
let mut state = self.state.lock().unwrap();
if !state.suspended {
log::debug!("Suspending API requests");

state.suspended = true;
let _ = self.tx.send(*state);
}
}

pub fn unsuspend(&self) {
let mut state = self.state.lock().unwrap();
if state.suspended {
log::debug!("Unsuspending API requests");

state.suspended = false;
let _ = self.tx.send(*state);
}
self.acquire().stop_inactivity_timer();
}

pub fn pause_background(&self) {
let mut state = self.state.lock().unwrap();
if !state.pause_background {
log::debug!("Pausing background API requests");

state.pause_background = true;
let _ = self.tx.send(*state);
}
self.acquire().pause_background();
}

pub fn resume_background(&self) {
if self.inactivity_timer_running() {
let should_reset = {
let mut inner = self.acquire();
inner.pause_background();
inner.inactivity_timer_running()
};
// Note: It is important that we do not hold on to the Mutex when calling `reset_inactivity_timer()`.
if should_reset {
self.reset_inactivity_timer();
}

let mut state = self.state.lock().unwrap();
if state.pause_background {
log::debug!("Resuming background API requests");
state.pause_background = false;
let _ = self.tx.send(*state);
}
}

fn set_inactive(&self) {
let mut state = self.state.lock().unwrap();
if !state.inactive {
log::debug!("Pausing background API requests due to inactivity");
state.inactive = true;
let _ = self.tx.send(*state);
}
pub fn suspend(&self) {
self.acquire().suspend()
}

fn set_active(&self) {
let mut state = self.state.lock().unwrap();
if state.inactive {
log::debug!("Resuming background API requests due to activity");
state.inactive = false;
let _ = self.tx.send(*state);
}
pub fn unsuspend(&self) {
self.acquire().unsuspend();
}

pub fn set_offline(&self, offline: bool) {
let mut state = self.state.lock().unwrap();
if state.offline != offline {
if offline {
log::debug!("Pausing API requests due to being offline");
} else {
log::debug!("Resuming API requests due to coming online");
}
self.acquire().set_offline(offline);
}

state.offline = offline;
let _ = self.tx.send(*state);
}
fn set_inactive(&self) {
self.acquire().set_inactive();
}

/// Check if the host is offline
pub fn is_offline(&self) -> bool {
self.get_state().is_offline()
}

pub fn get_state(&self) -> State {
*self.state.lock().unwrap()
fn get_state(&self) -> State {
self.acquire().state
}

pub fn wait_for_unsuspend(&self) -> impl Future<Output = Result<(), Error>> {
Expand Down Expand Up @@ -236,12 +160,12 @@ impl ApiAvailabilityHandle {
&self,
state_ready: impl Fn(State) -> bool,
) -> impl Future<Output = Result<(), Error>> {
let mut rx = self.tx.subscribe();
let state = self.state.clone();
let mut rx = self.acquire().tx.subscribe();

let handle = self.clone();
async move {
let current_state = { *state.lock().unwrap() };
if state_ready(current_state) {
let state = handle.get_state();
if state_ready(state) {
return Ok(());
}

Expand All @@ -254,3 +178,79 @@ impl ApiAvailabilityHandle {
}
}
}

impl ApiAvailabilityState {
fn suspend(&mut self) {
if !self.state.suspended {
log::trace!("Suspending API requests");
self.state.suspended = true;
let _ = self.tx.send(self.state);
}
}

fn unsuspend(&mut self) {
if self.state.suspended {
log::trace!("Unsuspending API requests");
self.state.suspended = false;
let _ = self.tx.send(self.state);
}
}

fn set_inactive(&mut self) {
log::trace!("Settings state to inactive");
if !self.state.inactive {
log::debug!("Pausing background API requests due to inactivity");
self.state.inactive = true;
let _ = self.tx.send(self.state);
}
}

fn set_active(&mut self) {
log::trace!("Settings state to active");
if self.state.inactive {
log::debug!("Resuming background API requests due to activity");
self.state.inactive = false;
let _ = self.tx.send(self.state).inspect_err(|send_err| {
log::debug!("All receivers of state updates have been dropped");
log::debug!("{send_err}");
});
}
}

fn set_offline(&mut self, offline: bool) {
if offline {
log::debug!("Pausing API requests due to being offline");
} else {
log::debug!("Resuming API requests due to coming online");
}
if self.state.offline != offline {
self.state.offline = offline;
let _ = self.tx.send(self.state);
}
}

fn pause_background(&mut self) {
if !self.state.pause_background {
log::debug!("Pausing background API requests");
self.state.pause_background = true;
let _ = self.tx.send(self.state);
}
}

fn stop_inactivity_timer(&mut self) {
log::debug!("Stopping API inactivity check");
if let Some(timer) = self.inactivity_timer.take() {
timer.abort();
}
}

const fn inactivity_timer_running(&self) -> bool {
self.inactivity_timer.is_some()
}
}

impl Drop for ApiAvailabilityState {
fn drop(&mut self) {
self.stop_inactivity_timer();
}
}
8 changes: 4 additions & 4 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{
use talpid_types::ErrorExt;

pub mod availability;
use availability::{ApiAvailability, ApiAvailabilityHandle};
use availability::ApiAvailability;
pub mod rest;

mod abortable_stream;
Expand Down Expand Up @@ -414,7 +414,7 @@ impl Runtime {
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.handle(),
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -467,8 +467,8 @@ impl Runtime {
&mut self.handle
}

pub fn availability_handle(&self) -> ApiAvailabilityHandle {
self.api_availability.handle()
pub fn availability_handle(&self) -> ApiAvailability {
self.api_availability.clone()
}

pub fn address_cache(&self) -> &AddressCache {
Expand Down
Loading

0 comments on commit bb10ecc

Please sign in to comment.