Skip to content

Commit

Permalink
Reset firewall rules after on_test_api_access_method
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Dec 19, 2023
1 parent 04fc2ad commit 38670ac
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
9 changes: 6 additions & 3 deletions mullvad-daemon/src/access_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ where
if api_access_method.get_id()
== self.get_current_access_method().await?.get_id() =>
{
// TODO(markus): I don't believe that this statement should be here.
self.connection_modes_handler.next().await?;
self.force_api_endpoint_rotation().await?;
}
_ => (),
}
Expand Down Expand Up @@ -169,7 +168,11 @@ where
/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
pub async fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
Ok(self.connection_modes_handler.get_access_method().await?)
self.connection_modes_handler
.get_current()
.await
.map(|current| current.setting)
.map_err(Error::ConnectionMode)
}

/// Change which [`AccessMethodSetting`] which will be used as the Mullvad
Expand Down
44 changes: 19 additions & 25 deletions mullvad-daemon/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use talpid_types::net::{
};

pub enum Message {
Get(ResponseTx<AccessMethodSetting>),
Get(ResponseTx<ResolvedConnectionMode>),
Set(ResponseTx<()>, AccessMethodSetting),
Next(ResponseTx<ApiConnectionMode>),
Update(ResponseTx<()>, Vec<AccessMethodSetting>),
Expand All @@ -48,9 +48,12 @@ pub enum AccessMethodEvent {
}

// TODO(markus): Comment this struct
#[derive(Clone)]
pub struct ResolvedConnectionMode {
pub connection_mode: ApiConnectionMode,
pub endpoint: AllowedEndpoint,
/// This is the setting which was resolved into `connection_mode` and `endpoint`.
pub setting: AccessMethodSetting,
}

#[derive(err_derive::Error, Debug)]
Expand Down Expand Up @@ -84,7 +87,7 @@ impl AccessModeSelectorHandle {
rx.await.map_err(Error::NotRunning)?
}

pub async fn get_access_method(&self) -> Result<AccessMethodSetting> {
pub async fn get_current(&self) -> Result<ResolvedConnectionMode> {
self.send_command(Message::Get).await.map_err(|err| {
log::error!("Failed to get current access method!");
err
Expand Down Expand Up @@ -163,7 +166,7 @@ pub struct AccessModeSelector {
address_cache: AddressCache,
/// All listeners of [`AccessMethodEvent`]s.
listeners: Vec<Box<dyn Sender<AccessMethodEvent> + Send>>,
last_resolved_connection_mode: ResolvedConnectionMode,
current: ResolvedConnectionMode,
}

// TODO(markus): Document this! It was created to get an initial api endpoint in
Expand Down Expand Up @@ -209,7 +212,7 @@ impl AccessModeSelector {
connection_modes,
address_cache,
listeners: vec![Box::new(listener)],
last_resolved_connection_mode: initial_connection_mode,
current: initial_connection_mode,
};

tokio::spawn(selector.into_future());
Expand Down Expand Up @@ -247,13 +250,8 @@ impl AccessModeSelector {
Ok(())
}

fn on_get_access_method(&mut self, tx: ResponseTx<AccessMethodSetting>) -> Result<()> {
let value = self.get_access_method();
self.reply(tx, value)
}

fn get_access_method(&mut self) -> AccessMethodSetting {
self.connection_modes.peek()
fn on_get_access_method(&mut self, tx: ResponseTx<ResolvedConnectionMode>) -> Result<()> {
self.reply(tx, self.current.clone())
}

fn on_set_access_method(
Expand All @@ -279,23 +277,23 @@ impl AccessModeSelector {
async fn next_connection_mode(&mut self) -> ApiConnectionMode {
let access_method = self.connection_modes.next().unwrap();
let next = {
let resolved = self.resolve(&access_method).await;
self.current = resolved.clone();
let ResolvedConnectionMode {
connection_mode,
endpoint,
} = self.resolve(&access_method).await;
let event = AccessMethodEvent::Active {
settings: access_method,
setting: settings,
endpoint,
};
..
} = resolved.clone();
let event = AccessMethodEvent::Active { settings, endpoint };
self.listeners
.retain(|listener| listener.send(event.clone()).is_ok());
connection_mode
resolved
};

// Save the new connection mode to cache!
{
let cache_dir = self.cache_dir.clone();
let new_connection_mode = next.clone();
let new_connection_mode = next.connection_mode.clone();
tokio::spawn(async move {
if new_connection_mode.save(&cache_dir).await.is_err() {
log::warn!(
Expand All @@ -305,7 +303,7 @@ impl AccessModeSelector {
}
});
}
next
next.connection_mode
}
fn on_update_access_methods(
&mut self,
Expand Down Expand Up @@ -345,6 +343,7 @@ impl AccessModeSelector {
ResolvedConnectionMode {
connection_mode,
endpoint,
setting: access_method.clone(),
}
}
}
Expand Down Expand Up @@ -446,11 +445,6 @@ impl ConnectionModesIterator {
Ok(Box::new(access_methods.into_iter().cycle()))
}
}

/// Look at the currently active [`AccessMethod`]
pub fn peek(&self) -> AccessMethodSetting {
self.current.clone()
}
}

impl Iterator for ConnectionModesIterator {
Expand Down
19 changes: 15 additions & 4 deletions mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2400,8 +2400,9 @@ where
let handle = self.connection_modes_handler.clone();
tokio::spawn(async move {
let result = handle
.get_access_method()
.get_current()
.await
.map(|current| current.setting)
.map_err(Error::ApiConnectionModeError);
Self::oneshot_send(tx, result, "get_current_api_access_method response");
});
Expand All @@ -2419,12 +2420,13 @@ where

match access_method_lookup {
Ok(access_method) => {
let access_method_selector = self.connection_modes_handler.clone();
// Create a stream of the access method to test.
let api::ResolvedConnectionMode {
connection_mode,
endpoint,
} = self
.connection_modes_handler
..
} = access_method_selector
.resolve(access_method.clone())
.await
// TODO(markus): Do not unwrap!
Expand Down Expand Up @@ -2456,9 +2458,18 @@ where
.await
.map_err(Error::RestError);

// TODO(markus):
// Tell the daemon to reset the hole we just punched to whatever was in place before.

// TODO(markus): Do not unwrap
let api::ResolvedConnectionMode { endpoint, .. } =
access_method_selector.get_current().await.unwrap();
let event = api::AccessMethodEvent::Testing {
endpoint,
update_finished_tx: update_finished_tx.clone(),
};
let _ = sender.send(event);
let _ = update_finished_rx.next().await;

log::info!(
"The result of testing {method:?} is {result}",
method = access_method.access_method,
Expand Down

0 comments on commit 38670ac

Please sign in to comment.