diff --git a/Cargo.lock b/Cargo.lock index e5982e380f35..58283379e147 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5087,6 +5087,11 @@ dependencies = [ [[package]] name = "wireguard-go-rs" version = "0.0.0" +dependencies = [ + "log", + "thiserror", + "zeroize", +] [[package]] name = "x25519-dalek" @@ -5122,9 +5127,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" dependencies = [ "zeroize_derive", ] diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs index a2b1651ca8b9..e04fb900d99f 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity_check.rs @@ -5,9 +5,10 @@ use crate::{ use std::{ cmp, net::Ipv4Addr, - sync::{mpsc, Mutex, Weak}, + sync::{mpsc, Weak}, time::{Duration, Instant}, }; +use tokio::sync::Mutex; use super::{Tunnel, TunnelError}; @@ -214,8 +215,7 @@ impl ConnectivityMonitor { fn get_stats(&self) -> Option> { self.tunnel_handle .upgrade()? - .lock() - .ok()? + .blocking_lock() .as_ref() .and_then(|tunnel| match tunnel.get_tunnel_stats() { Ok(stats) if stats.is_empty() => { @@ -551,7 +551,7 @@ mod test { rx_bytes: 0, }, ); - let peers = Mutex::new(map); + let peers = std::sync::Mutex::new(map); Self { on_get_stats: Box::new(move || { let mut peers = peers.lock().unwrap(); @@ -746,7 +746,7 @@ mod test { rx_bytes: 0, }, ); - let tunnel_stats = Mutex::new(map); + let tunnel_stats = std::sync::Mutex::new(map); let pinger = MockPinger::default(); let (_tunnel_anchor, tunnel) = MockTunnel::new(move || { diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index d4ac66c18a1d..d2198ae2e759 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -144,7 +144,7 @@ impl Error { pub struct WireguardMonitor { runtime: tokio::runtime::Handle, /// Tunnel implementation - tunnel: Arc>>>, + tunnel: Arc>>>, /// Callback to signal tunnel events event_callback: EventCallback, close_msg_receiver: sync_mpsc::Receiver, @@ -306,7 +306,7 @@ impl WireguardMonitor { let (pinger_tx, pinger_rx) = sync_mpsc::channel(); let monitor = WireguardMonitor { runtime: args.runtime.clone(), - tunnel: Arc::new(Mutex::new(Some(tunnel))), + tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), event_callback, close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, @@ -473,7 +473,7 @@ impl WireguardMonitor { #[allow(clippy::too_many_arguments)] async fn config_ephemeral_peers( - tunnel: &Arc>>>, + tunnel: &Arc>>>, config: &mut Config, retry_attempt: u32, on_event: F, @@ -579,7 +579,7 @@ impl WireguardMonitor { #[cfg(daita)] if config.daita { // Start local DAITA machines - let mut tunnel = tunnel.lock().unwrap(); + let mut tunnel = tunnel.lock().await; if let Some(tunnel) = tunnel.as_mut() { tunnel .start_daita() @@ -601,7 +601,7 @@ impl WireguardMonitor { /// Reconfigures the tunnel to use the provided config while potentially modifying the config /// and restarting the obfuscation provider. Returns the new config used by the new tunnel. async fn reconfigure_tunnel( - tunnel: &Arc>>>, + tunnel: &Arc>>>, mut config: Config, obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, @@ -625,11 +625,12 @@ impl WireguardMonitor { } } + let tunnel = tunnel.lock().await; + let set_config_future = tunnel - .lock() - .unwrap() .as_ref() .map(|tunnel| tunnel.set_config(config.clone())); + if let Some(f) = set_config_future { f.await .map_err(Error::TunnelError) @@ -847,7 +848,7 @@ impl WireguardMonitor { } fn stop_tunnel(&mut self) { - match self.tunnel.lock().expect("Tunnel lock poisoned").take() { + match self.tunnel.blocking_lock().take() { Some(tunnel) => { if let Err(e) = tunnel.stop() { log::error!("{}", e.display_chain_with_msg("Failed to stop tunnel")); @@ -1028,10 +1029,10 @@ pub(crate) trait Tunnel: Send { fn get_interface_name(&self) -> String; fn stop(self: Box) -> std::result::Result<(), TunnelError>; fn get_tunnel_stats(&self) -> std::result::Result; - fn set_config( - &self, + fn set_config<'a>( + &'a self, // TODO: should be &mut ?? _config: Config, - ) -> Pin> + Send>>; + ) -> Pin> + Send + 'a>>; #[cfg(daita)] /// A [`Tunnel`] capable of using DAITA. fn start_daita(&mut self) -> std::result::Result<(), TunnelError>; @@ -1055,10 +1056,19 @@ pub enum TunnelError { #[error("Failed to start wireguard tunnel")] FatalStartWireguardError, + /// Failed to start the wireguard tunnel. + #[error("Failed to start wireguard-go tunnel: {status}")] + StartWireguardError { + // TODO: consider doing a Box intead + /// Implementation-specific error code. + status: i32, + }, + /// Failed to tear down wireguard tunnel. - #[error("Failed to stop wireguard tunnel. Status: {status}")] + #[error("Failed to stop wireguard tunnel: {status}")] StopWireguardError { - /// Returned error code + // TODO: consider doing a Box intead + /// Implementation-specific error code. status: i32, }, @@ -1114,8 +1124,8 @@ pub enum TunnelError { /// Failed to receive DAITA event #[cfg(daita)] - #[error("Failed to receive DAITA event")] - DaitaReceiveEvent(i32), + #[error("Failed to start DAITA")] + StartDaita(#[source] Box), /// This tunnel does not support DAITA. #[cfg(daita)] diff --git a/talpid-wireguard/src/wireguard_go/daita.rs b/talpid-wireguard/src/wireguard_go/daita.rs deleted file mode 100644 index 67c074b8b3b9..000000000000 --- a/talpid-wireguard/src/wireguard_go/daita.rs +++ /dev/null @@ -1,46 +0,0 @@ -#![cfg(daita)] -use std::{ffi::CStr, io}; - -use talpid_types::net::wireguard::PublicKey; -use wireguard_go_rs::wgActivateDaita; - -/// Maximum number of events that can be stored in the underlying buffer -const EVENTS_CAPACITY: u32 = 1000; -/// Maximum number of actions that can be stored in the underlying buffer -const ACTIONS_CAPACITY: u32 = 1000; - -#[derive(Debug)] -pub struct Session { - _tunnel_handle: i32, -} - -impl Session { - /// Enable DAITA for an existing WireGuard interface. - pub(super) fn from_adapter( - tunnel_handle: i32, - peer_public_key: &PublicKey, - machines: &CStr, - ) -> io::Result { - // SAFETY: - // peer_public_key and machines lives for the duration of this function call. - - // TODO: ´machines` must be valid UTF-8 - let res = unsafe { - wgActivateDaita( - tunnel_handle, - peer_public_key.as_bytes().as_ptr(), - machines.as_ptr(), - EVENTS_CAPACITY, - ACTIONS_CAPACITY, - ) - }; - - if res != 0 { - // TODO: return error - panic!("Failed to activate DAITA on tunnel {tunnel_handle}, error code: {res}"); - } - Ok(Self { - _tunnel_handle: tunnel_handle, - }) - } -} diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 1123bf3b7ed8..d3afc15a3e76 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,10 +1,8 @@ use ipnetwork::IpNetwork; #[cfg(daita)] use once_cell::sync::OnceCell; -#[cfg(daita)] -use std::{ffi::CString, fs, path::PathBuf}; use std::{ - ffi::{c_char, c_void, CStr}, + ffi::c_void, future::Future, net::IpAddr, os::unix::io::{AsRawFd, RawFd}, @@ -12,14 +10,12 @@ use std::{ pin::Pin, sync::{Arc, Mutex}, }; +#[cfg(daita)] +use std::{ffi::CString, fs, path::PathBuf}; #[cfg(target_os = "android")] use talpid_tunnel::tun_provider::Error as TunProviderError; use talpid_tunnel::tun_provider::{Tun, TunConfig, TunProvider}; use talpid_types::BoxedError; -use zeroize::Zeroize; - -#[cfg(daita)] -mod daita; use super::{ stats::{Stats, StatsMap}, @@ -27,10 +23,16 @@ use super::{ }; use crate::logging::{clean_up_logging, initialize_logging}; -use wireguard_go_rs::*; - const MAX_PREPARE_TUN_ATTEMPTS: usize = 4; +/// Maximum number of events that can be stored in the underlying buffer +#[cfg(daita)] +const DAITA_EVENTS_CAPACITY: u32 = 1000; + +/// Maximum number of actions that can be stored in the underlying buffer +#[cfg(daita)] +const DAITA_ACTIONS_CAPACITY: u32 = 1000; + type Result = std::result::Result; struct LoggingContext(u32); @@ -43,7 +45,7 @@ impl Drop for LoggingContext { pub struct WgGoTunnel { interface_name: String, - tunnel_handle: i32, + tunnel_handle: wireguard_go_rs::Tunnel, // holding on to the tunnel device and the log file ensures that the associated file handles // live long enough and get closed when the tunnel is stopped _tunnel_device: Tun, @@ -52,8 +54,6 @@ pub struct WgGoTunnel { #[cfg(target_os = "android")] tun_provider: Arc>, #[cfg(daita)] - daita_handle: Option, - #[cfg(daita)] resource_dir: PathBuf, #[cfg(daita)] config: Config, @@ -81,17 +81,14 @@ impl WgGoTunnel { #[cfg(not(target_os = "android"))] let mtu = config.mtu as isize; - let handle = unsafe { - wgTurnOn( - #[cfg(not(target_os = "android"))] - mtu, - wg_config_str.as_ptr() as _, - tunnel_fd, - Some(logging::wg_go_logging_callback), - logging_context.0 as *mut c_void, - ) - }; - check_wg_status(handle)?; + let handle = wireguard_go_rs::Tunnel::turn_on( + mtu, + &wg_config_str, + tunnel_fd, + Some(logging::wg_go_logging_callback), + logging_context.0 as *mut c_void, + ) + .map_err(|e| TunnelError::StartWireguardError { status: e.as_raw() })?; #[cfg(target_os = "android")] Self::bypass_tunnel_sockets(&mut tunnel_device, handle) @@ -107,8 +104,6 @@ impl WgGoTunnel { #[cfg(daita)] resource_dir: resource_dir.to_owned(), #[cfg(daita)] - daita_handle: None, - #[cfg(daita)] config: config.clone(), }) } @@ -153,14 +148,6 @@ impl WgGoTunnel { Ok(()) } - fn stop_tunnel(&mut self) -> Result<()> { - let status = unsafe { wgTurnOff(self.tunnel_handle) }; - if status < 0 { - return Err(TunnelError::StopWireguardError { status }); - } - Ok(()) - } - fn get_tunnel( tun_provider: Arc>, config: &Config, @@ -191,64 +178,37 @@ impl WgGoTunnel { } } -impl Drop for WgGoTunnel { - fn drop(&mut self) { - if let Err(e) = self.stop_tunnel() { - log::error!("Failed to stop tunnel: {}", e); - } - } -} - impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> String { self.interface_name.clone() } fn get_tunnel_stats(&self) -> Result { - let config_str = unsafe { - let ptr = wgGetConfig(self.tunnel_handle); - if ptr.is_null() { - log::error!("Failed to get config !"); - return Err(TunnelError::GetConfigError); - } - - CStr::from_ptr(ptr) - }; - - let result = - Stats::parse_config_str(config_str.to_str().expect("Go strings are always UTF-8")) - .map_err(|error| TunnelError::StatsError(BoxedError::new(error))); - unsafe { - // Zeroing out config string to not leave private key in memory. - let slice = std::slice::from_raw_parts_mut( - config_str.as_ptr() as *mut c_char, - config_str.to_bytes().len(), - ); - slice.zeroize(); - - wgFreePtr(config_str.as_ptr() as *mut c_void); - } - - result + self.tunnel_handle + .get_config(|cstr| { + Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8")) + }) + .ok_or(TunnelError::GetConfigError)? + .map_err(|error| TunnelError::StatsError(BoxedError::new(error))) } - fn stop(mut self: Box) -> Result<()> { - self.stop_tunnel() + fn stop(self: Box) -> Result<()> { + self.tunnel_handle + .turn_off() + .map_err(|e| TunnelError::StopWireguardError { status: e.as_raw() }) } - fn set_config( - &self, - config: Config, - ) -> Pin> + Send>> { - let wg_config_str = config.to_userspace_format(); - let handle = self.tunnel_handle; - #[cfg(target_os = "android")] - let tun_provider = self.tun_provider.clone(); + // TODO: should probably be &mut to guard against concurrency issues? + fn set_config(&self, config: Config) -> Pin> + Send + '_>> { Box::pin(async move { - let status = unsafe { wgSetConfig(handle, wg_config_str.as_ptr() as _) }; - if status != 0 { - return Err(TunnelError::SetConfigError); - } + let wg_config_str = config.to_userspace_format(); + + self.tunnel_handle + .set_config(&wg_config_str) + .map_err(|_| TunnelError::SetConfigError)?; + + #[cfg(target_os = "android")] + let tun_provider = self.tun_provider.clone(); // When reapplying the config, the endpoint socket may be discarded // and needs to be excluded again @@ -271,12 +231,6 @@ impl Tunnel for WgGoTunnel { #[cfg(daita)] fn start_daita(&mut self) -> Result<()> { - if let Some(_handle) = self.daita_handle.take() { - log::info!("Stopping previous DAITA machines"); - // let _ = handle.close(); - todo!("Closing existing DAITA instance") - } - static MAYBENOT_MACHINES: OnceCell = OnceCell::new(); let machines = MAYBENOT_MACHINES.get_or_try_init(|| { let path = self.resource_dir.join("maybenot_machines"); @@ -290,29 +244,19 @@ impl Tunnel for WgGoTunnel { log::info!("Initializing DAITA for wireguard device"); let peer_public_key = &self.config.entry_peer.public_key; - let session = daita::Session::from_adapter(self.tunnel_handle, peer_public_key, machines) - .expect("Wireguard-go should fetch current tunnel from ID"); - self.daita_handle = Some(session); + self.tunnel_handle + .activate_daita( + peer_public_key.as_bytes(), + machines, + DAITA_EVENTS_CAPACITY, + DAITA_ACTIONS_CAPACITY, + ) + .map_err(|e| TunnelError::StartDaita(Box::new(e)))?; Ok(()) } } -fn check_wg_status(wg_code: i32) -> Result<()> { - match wg_code { - ERROR_GENERAL_FAILURE => Err(TunnelError::FatalStartWireguardError), - ERROR_INTERMITTENT_FAILURE => Err(TunnelError::RecoverableStartWireguardError), - 0.. => Ok(()), - _ => { - log::error!("Unknown status code returned from wireguard-go"); - Err(TunnelError::FatalStartWireguardError) - } - } -} - -const ERROR_GENERAL_FAILURE: i32 = -1; -const ERROR_INTERMITTENT_FAILURE: i32 = -2; - mod stats { use super::{Stats, StatsMap}; diff --git a/wireguard-go-rs/Cargo.toml b/wireguard-go-rs/Cargo.toml index 19c397153c81..3725787bf6dc 100644 --- a/wireguard-go-rs/Cargo.toml +++ b/wireguard-go-rs/Cargo.toml @@ -3,3 +3,8 @@ name = "wireguard-go-rs" description = "Rust bindings to wireguard-go with DAITA support" edition = "2021" license.workspace = true + +[dependencies] +thiserror.workspace = true +log.workspace = true +zeroize = "1.8.1" diff --git a/wireguard-go-rs/libwg/libwg.go b/wireguard-go-rs/libwg/libwg.go index 1770c133bdd6..28f18d5cac78 100644 --- a/wireguard-go-rs/libwg/libwg.go +++ b/wireguard-go-rs/libwg/libwg.go @@ -25,10 +25,24 @@ import ( // FFI integer result codes const ( OK = C.int32_t(-iota) + + // Something went wrong. ERROR_GENERAL_FAILURE + + // Something went wrong, but trying again might help. ERROR_INTERMITTENT_FAILURE + + // A bad argument was provided to libwg. + ERROR_INVALID_ARGUMENT + + // The provided tunnel handle did not refer to an existing tunnel. ERROR_UNKNOWN_TUNNEL + + // The provided public key did not refer to an existing peer. ERROR_UNKNOWN_PEER + + // Something went wrong when enabling DAITA. + // TODO: consider removing this? should probably be replaced by more specific errors ERROR_ENABLE_DAITA ) @@ -81,7 +95,7 @@ func wgSetConfig(tunnelHandle int32, cSettings *C.char) C.int32_t { } if cSettings == nil { tunnel.Logger.Errorf("cSettings is null\n") - return ERROR_GENERAL_FAILURE + return ERROR_INVALID_ARGUMENT } settings := C.GoString(cSettings) diff --git a/wireguard-go-rs/libwg/libwg_android.go b/wireguard-go-rs/libwg/libwg_android.go index f40ac1c1dac6..c50d4a7471c0 100644 --- a/wireguard-go-rs/libwg/libwg_android.go +++ b/wireguard-go-rs/libwg/libwg_android.go @@ -35,7 +35,7 @@ func wgTurnOn(cSettings *C.char, fd int, logSink LogSink, logContext LogContext) if cSettings == nil { logger.Errorf("cSettings is null\n") - return ERROR_GENERAL_FAILURE + return ERROR_INVALID_ARGUMENT } settings := C.GoString(cSettings) diff --git a/wireguard-go-rs/libwg/libwg_default.go b/wireguard-go-rs/libwg/libwg_default.go index 6fe0d794a83c..5e5c111069b8 100644 --- a/wireguard-go-rs/libwg/libwg_default.go +++ b/wireguard-go-rs/libwg/libwg_default.go @@ -38,7 +38,7 @@ func wgTurnOn(mtu int, cSettings *C.char, fd int, logSink LogSink, logContext Lo if cSettings == nil { logger.Errorf("cSettings is null\n") - return ERROR_GENERAL_FAILURE + return ERROR_INVALID_ARGUMENT } settings := C.GoString(cSettings) diff --git a/wireguard-go-rs/src/lib.rs b/wireguard-go-rs/src/lib.rs index 37d5546a7763..042ca18770fb 100644 --- a/wireguard-go-rs/src/lib.rs +++ b/wireguard-go-rs/src/lib.rs @@ -1,70 +1,265 @@ #![cfg(unix)] + +use core::slice; +use std::{ + ffi::{c_char, c_void, CStr}, + mem::ManuallyDrop, +}; +use zeroize::Zeroize; + pub type Fd = std::os::unix::io::RawFd; -use std::ffi::{c_char, c_void}; + pub type WgLogLevel = u32; + pub type LoggingCallback = unsafe extern "system" fn(level: WgLogLevel, msg: *const c_char, context: *mut c_void); -extern "C" { - /// Creates a new wireguard tunnel, uses the specific interface name, MTU and file descriptors - /// for the tunnel device and logging. - /// - /// Positive return values are tunnel handles for this specific wireguard tunnel instance. - /// Negative return values signify errors. All error codes are opaque. - #[cfg(not(target_os = "android"))] - pub fn wgTurnOn( - mtu: isize, - settings: *const i8, - fd: Fd, - logging_callback: Option, - logging_context: *mut c_void, - ) -> i32; +/// A wireguard-go tunnel +pub struct Tunnel { + /// wireguard-go handle to the tunnel. + handle: i32, +} - // Android - #[cfg(target_os = "android")] - pub fn wgTurnOn( - settings: *const i8, - fd: Fd, +// NOTE: Must be kept in sync with libwg.go +// NOTE: must be kept in sync with `result_from_code` +// INVARIANT: Will aways be represented as a negative i32 +#[repr(i32)] +#[non_exhaustive] +#[derive(Clone, Copy, Debug, thiserror::Error)] +pub enum Error { + #[error("Something went wrong.")] + GeneralFailure = -1, + + #[error("Something went wrong, but trying again might help.")] + IntermittentFailure = -2, + + #[error("An argument you provided was invalid.")] + InvalidArgument = -3, + + #[error("The tunnel handle did not refer to an existing tunnel.")] + UnknownTunnel = -4, + + #[error("The provided public key did not refer to an existing peer.")] + UnknownPeer = -5, + + #[error("Something went wrong when enabling DAITA.")] + EnableDaita = -6, + + #[error("`libwg` provided an unknown error code. This is a bug.")] + Other = i32::MIN, +} + +impl Tunnel { + // TODO: this function is supposed to be a safe wrapper, but as clippy points out, + // the logging_context is a *mut, which may unsafely be dereferenced by the callback. + // I'd prefer NOT to mark this functon as unsafe though... + pub fn turn_on( + #[cfg(not(target_os = "android"))] + mtu: isize, + settings: &CStr, + device: Fd, logging_callback: Option, logging_context: *mut c_void, - ) -> i32; + ) -> Result { + // SAFETY: pointer is valid for the the lifetime of this function + let code = unsafe { + ffi::wgTurnOn( + #[cfg(not(target_os = "android"))] + mtu, - // Pass a handle that was created by wgTurnOn to stop a wireguard tunnel. - pub fn wgTurnOff(handle: i32) -> i32; + settings.as_ptr(), + device, + logging_callback, + logging_context, + ) + }; - // Returns the file descriptor of the tunnel IPv4 socket. - pub fn wgGetConfig(handle: i32) -> *mut c_char; + result_from_code(code)?; + Ok(Tunnel { handle: code }) + } - // Sets the config of the WireGuard interface. - pub fn wgSetConfig(handle: i32, settings: *const i8) -> i32; + pub fn turn_off(self) -> Result<(), Error> { + // we manually turn off the tunnel here, so wrap it in ManuallyDrop to prevent the Drop + // impl from doing the same. + let code = unsafe { ffi::wgTurnOff(self.handle) }; + let _ = ManuallyDrop::new(self); + result_from_code(code) + } - /// Activate DAITA for the specified peer. + /// Get the config of the WireGuard interface and make it available in the provided function. /// - /// `tunnel_handle` must come from [wgTurnOn]. `machines` is a string containing LF-separated - /// maybenot machines. + /// This takes a function to make sure the cstr get's zeroed and freed afterwards. + /// Returns `None` if the call to wgGetConfig returned nil. /// - /// # Safety: - /// - `peer_public_key` must point to a 32 byte array. - /// - `machines` must point to a null-terminated UTF-8 string. - /// - Neither pointer will be written to by `wgActivateDaita`. - /// - Neither pointer will be read from after `wgActivateDaita` has returned. + /// **NOTE:** You should take extra care to avoid copying any secrets from the config without zeroizing them afterwards. + // NOTE: this could return a guard type with a custom Drop impl instead, but me lazy. + pub fn get_config(&self, f: impl FnOnce(&CStr) -> T) -> Option { + // SAFETY: TODO: what to write here? + let ptr = unsafe { ffi::wgGetConfig(self.handle) }; + + if ptr.is_null() { + return None; + } + + + // contain any cast of ptr->ref within dedicated blocks to prevent accidents + let config_len: usize; + let t: T; + { + // SAFETY: we checked for null, and wgGetConfig promises that this is a valid cstr + let config = unsafe { CStr::from_ptr(ptr) }; + config_len = config.to_bytes().len(); + t = f(config); + } + + { + // SAFETY: + // we checked for null, and wgGetConfig promises that this is a valid cstr. + // config_len comes from the CStr above, so it should be good. + let config_bytes = unsafe { + slice::from_raw_parts_mut(ptr, config_len) + }; + config_bytes.zeroize(); + } + + // SAFETY: the pointer was created by wgGetConfig, and we are no longer using it. + unsafe { ffi::wgFreePtr(ptr.cast()) }; + + Some(t) + } + + pub fn set_config(&self, config: &CStr) -> Result<(), Error> { + // SAFETY: pointer is valid for the lifetime of this function. + let code = unsafe { ffi::wgSetConfig(self.handle, config.as_ptr()) }; + result_from_code(code) + } + #[cfg(daita)] - pub fn wgActivateDaita( - tunnel_handle: i32, - peer_public_key: *const u8, - machines: *const c_char, + pub fn activate_daita( + &self, + peer_public_key: &[u8; 32], + machines: &CStr, events_capacity: u32, actions_capacity: u32, - ) -> i32; + ) -> Result<(), Error> { + // SAFETY: pointers are valid for the lifetime of this function. + let code = unsafe { + ffi::wgActivateDaita( + self.handle, + peer_public_key.as_ptr(), + machines.as_ptr(), + events_capacity, + actions_capacity, + ) + }; - // Frees a pointer allocated by the go runtime - useful to free return value of wgGetConfig - pub fn wgFreePtr(ptr: *mut c_void); + result_from_code(code) + } - // Returns the file descriptor of the tunnel IPv4 socket. + /// Get the file descriptor of the tunnel IPv4 socket. #[cfg(target_os = "android")] - pub fn wgGetSocketV4(handle: i32) -> Fd; + pub fn get_socket_v4(&self) -> Fd { + unsafe { ffi::wgGetSocketV4(self.handle) } + } - // Returns the file descriptor of the tunnel IPv6 socket. + /// Get the file descriptor of the tunnel IPv6 socket. #[cfg(target_os = "android")] - pub fn wgGetSocketV6(handle: i32) -> Fd; + pub fn get_socket_v6(&self) -> Fd { + unsafe { ffi::wgGetSocketV6(self.handle) } + } +} + +impl Drop for Tunnel { + fn drop(&mut self) { + let code = unsafe { ffi::wgTurnOff(self.handle) }; + if let Err(e) = result_from_code(code) { + log::error!("Failed to stop wireguard-go tunnel,oerror_code={code} ({e:?})") + } + } +} + +fn result_from_code(code: i32) -> Result<(), Error> { + // NOTE: must be kept in sync with enum definition + Err(match code { + 0.. => return Ok(()), + -1 => Error::GeneralFailure, + -2 => Error::IntermittentFailure, + -3 => Error::UnknownTunnel, + -4 => Error::UnknownPeer, + -5 => Error::EnableDaita, + _ => Error::Other, + }) +} + +impl Error { + pub const fn as_raw(self) -> i32 { + self as i32 + } +} + +mod ffi { + use super::{Fd, LoggingCallback}; + use core::ffi::{c_char, c_void}; + + extern "C" { + /// Creates a new wireguard tunnel, uses the specific interface name, and file descriptors + /// for the tunnel device and logging. For targets other than android, this also takes an MTU value. + /// + /// Positive return values are tunnel handles for this specific wireguard tunnel instance. + /// Negative return values signify errors. All error codes are opaque. + pub fn wgTurnOn( + #[cfg(not(target_os = "android"))] + mtu: isize, + settings: *const i8, + fd: Fd, + logging_callback: Option, + logging_context: *mut c_void, + ) -> i32; + + /// Pass a handle that was created by wgTurnOn to stop a wireguard tunnel. + pub fn wgTurnOff(handle: i32) -> i32; + + /// Get the config of the WireGuard interface. + /// + /// # Safety: + /// - The function returns an owned pointer to a null-terminated UTF-8 string. + /// - The pointer may only be freed using [wgFreePtr]. + pub fn wgGetConfig(handle: i32) -> *mut c_char; + + /// Set the config of the WireGuard interface. + /// + /// # Safety: + /// - `settings` must point to a null-terminated UTF-8 string. + /// - The pointer will not be read from after `wgActivateDaita` has returned. + pub fn wgSetConfig(handle: i32, settings: *const i8) -> i32; + + /// Activate DAITA for the specified peer. + /// + /// `tunnel_handle` must come from [wgTurnOn]. `machines` is a string containing LF-separated + /// maybenot machines. + /// + /// # Safety: + /// - `peer_public_key` must point to a 32 byte array. + /// - `machines` must point to a null-terminated UTF-8 string. + /// - Neither pointer will be read from after `wgActivateDaita` has returned. + #[cfg(daita)] + pub fn wgActivateDaita( + tunnel_handle: i32, + peer_public_key: *const u8, + machines: *const c_char, + events_capacity: u32, + actions_capacity: u32, + ) -> i32; + + /// Free a pointer allocated by the go runtime - useful to free return value of wgGetConfig + pub fn wgFreePtr(ptr: *mut c_void); + + /// Get the file descriptor of the tunnel IPv4 socket. + #[cfg(target_os = "android")] + pub fn wgGetSocketV4(handle: i32) -> Fd; + + /// Get the file descriptor of the tunnel IPv6 socket. + #[cfg(target_os = "android")] + pub fn wgGetSocketV6(handle: i32) -> Fd; + } }