Skip to content

Commit

Permalink
Add a safe FFI wrapper in wireguard-go-rs
Browse files Browse the repository at this point in the history
  • Loading branch information
hulthe committed May 29, 2024
1 parent 13d49d4 commit 7488bd9
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 221 deletions.
9 changes: 7 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions talpid-wireguard/src/connectivity_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use crate::{
use std::{
cmp,
net::Ipv4Addr,
sync::{mpsc, Mutex, Weak},
sync::{mpsc, Weak},
time::{Duration, Instant},
};
use tokio::sync::Mutex;

use super::{Tunnel, TunnelError};

Expand Down Expand Up @@ -214,8 +215,7 @@ impl ConnectivityMonitor {
fn get_stats(&self) -> Option<Result<StatsMap, Error>> {
self.tunnel_handle
.upgrade()?
.lock()
.ok()?
.blocking_lock()
.as_ref()
.and_then(|tunnel| match tunnel.get_tunnel_stats() {
Ok(stats) if stats.is_empty() => {
Expand Down Expand Up @@ -551,7 +551,7 @@ mod test {
rx_bytes: 0,
},
);
let peers = Mutex::new(map);
let peers = std::sync::Mutex::new(map);
Self {
on_get_stats: Box::new(move || {
let mut peers = peers.lock().unwrap();
Expand Down Expand Up @@ -746,7 +746,7 @@ mod test {
rx_bytes: 0,
},
);
let tunnel_stats = Mutex::new(map);
let tunnel_stats = std::sync::Mutex::new(map);

let pinger = MockPinger::default();
let (_tunnel_anchor, tunnel) = MockTunnel::new(move || {
Expand Down
40 changes: 25 additions & 15 deletions talpid-wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ impl Error {
pub struct WireguardMonitor {
runtime: tokio::runtime::Handle,
/// Tunnel implementation
tunnel: Arc<Mutex<Option<Box<dyn Tunnel>>>>,
tunnel: Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
/// Callback to signal tunnel events
event_callback: EventCallback,
close_msg_receiver: sync_mpsc::Receiver<CloseMsg>,
Expand Down Expand Up @@ -306,7 +306,7 @@ impl WireguardMonitor {
let (pinger_tx, pinger_rx) = sync_mpsc::channel();
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
tunnel: Arc::new(Mutex::new(Some(tunnel))),
tunnel: Arc::new(AsyncMutex::new(Some(tunnel))),
event_callback,
close_msg_receiver: close_obfs_listener,
pinger_stop_sender: pinger_tx,
Expand Down Expand Up @@ -473,7 +473,7 @@ impl WireguardMonitor {

#[allow(clippy::too_many_arguments)]
async fn config_ephemeral_peers<F>(
tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
config: &mut Config,
retry_attempt: u32,
on_event: F,
Expand Down Expand Up @@ -579,7 +579,7 @@ impl WireguardMonitor {
#[cfg(daita)]
if config.daita {
// Start local DAITA machines
let mut tunnel = tunnel.lock().unwrap();
let mut tunnel = tunnel.lock().await;
if let Some(tunnel) = tunnel.as_mut() {
tunnel
.start_daita()
Expand All @@ -601,7 +601,7 @@ impl WireguardMonitor {
/// Reconfigures the tunnel to use the provided config while potentially modifying the config
/// and restarting the obfuscation provider. Returns the new config used by the new tunnel.
async fn reconfigure_tunnel(
tunnel: &Arc<Mutex<Option<Box<dyn Tunnel>>>>,
tunnel: &Arc<AsyncMutex<Option<Box<dyn Tunnel>>>>,
mut config: Config,
obfuscator: Arc<AsyncMutex<Option<ObfuscatorHandle>>>,
close_obfs_sender: sync_mpsc::Sender<CloseMsg>,
Expand All @@ -625,11 +625,12 @@ impl WireguardMonitor {
}
}

let tunnel = tunnel.lock().await;

let set_config_future = tunnel
.lock()
.unwrap()
.as_ref()
.map(|tunnel| tunnel.set_config(config.clone()));

if let Some(f) = set_config_future {
f.await
.map_err(Error::TunnelError)
Expand Down Expand Up @@ -847,7 +848,7 @@ impl WireguardMonitor {
}

fn stop_tunnel(&mut self) {
match self.tunnel.lock().expect("Tunnel lock poisoned").take() {
match self.tunnel.blocking_lock().take() {
Some(tunnel) => {
if let Err(e) = tunnel.stop() {
log::error!("{}", e.display_chain_with_msg("Failed to stop tunnel"));
Expand Down Expand Up @@ -1028,10 +1029,10 @@ pub(crate) trait Tunnel: Send {
fn get_interface_name(&self) -> String;
fn stop(self: Box<Self>) -> std::result::Result<(), TunnelError>;
fn get_tunnel_stats(&self) -> std::result::Result<stats::StatsMap, TunnelError>;
fn set_config(
&self,
fn set_config<'a>(
&'a self, // TODO: should be &mut ??
_config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send>>;
) -> Pin<Box<dyn Future<Output = std::result::Result<(), TunnelError>> + Send + 'a>>;
#[cfg(daita)]
/// A [`Tunnel`] capable of using DAITA.
fn start_daita(&mut self) -> std::result::Result<(), TunnelError>;
Expand All @@ -1055,10 +1056,19 @@ pub enum TunnelError {
#[error("Failed to start wireguard tunnel")]
FatalStartWireguardError,

/// Failed to start the wireguard tunnel.
#[error("Failed to start wireguard-go tunnel: {status}")]
StartWireguardError {
// TODO: consider doing a Box<dyn Error> intead
/// Implementation-specific error code.
status: i32,
},

/// Failed to tear down wireguard tunnel.
#[error("Failed to stop wireguard tunnel. Status: {status}")]
#[error("Failed to stop wireguard tunnel: {status}")]
StopWireguardError {
/// Returned error code
// TODO: consider doing a Box<dyn Error> intead
/// Implementation-specific error code.
status: i32,
},

Expand Down Expand Up @@ -1114,8 +1124,8 @@ pub enum TunnelError {

/// Failed to receive DAITA event
#[cfg(daita)]
#[error("Failed to receive DAITA event")]
DaitaReceiveEvent(i32),
#[error("Failed to start DAITA")]
StartDaita(#[source] Box<dyn std::error::Error + Send>),

/// This tunnel does not support DAITA.
#[cfg(daita)]
Expand Down
46 changes: 0 additions & 46 deletions talpid-wireguard/src/wireguard_go/daita.rs

This file was deleted.

Loading

0 comments on commit 7488bd9

Please sign in to comment.