diff --git a/android/lib/talpid/build.gradle.kts b/android/lib/talpid/build.gradle.kts index a5cd613de189..d3ec06276e5b 100644 --- a/android/lib/talpid/build.gradle.kts +++ b/android/lib/talpid/build.gradle.kts @@ -32,6 +32,7 @@ dependencies { implementation(projects.lib.model) implementation(libs.androidx.lifecycle.service) + implementation(libs.androidx.ktx) implementation(libs.kermit) implementation(libs.kotlin.stdlib) implementation(libs.kotlinx.coroutines.android) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index edeec9a6fe9d..7433d6b14f02 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -6,6 +6,7 @@ import android.net.ConnectivityManager.NetworkCallback import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkRequest +import co.touchlab.kermit.Logger import kotlin.properties.Delegates.observable class ConnectivityListener { @@ -14,6 +15,7 @@ class ConnectivityListener { private val callback = object : NetworkCallback() { override fun onAvailable(network: Network) { + Logger.d("Network $network") availableNetworks.add(network) isConnected = true } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 9470c88318b6..b2c171151a22 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -1,15 +1,34 @@ package net.mullvad.talpid +import android.net.ConnectivityManager +import android.net.LinkProperties import android.os.ParcelFileDescriptor +import android.system.Os.socket import androidx.annotation.CallSuper +import androidx.core.content.getSystemService +import androidx.lifecycle.lifecycleScope import co.touchlab.kermit.Logger import java.net.Inet4Address import java.net.Inet6Address import java.net.InetAddress import kotlin.properties.Delegates.observable +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.measureTimedValue +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeoutOrNull import net.mullvad.talpid.model.CreateTunResult import net.mullvad.talpid.model.TunConfig +import net.mullvad.talpid.util.NetworkEvent import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported +import net.mullvad.talpid.util.defaultCallbackFlow open class TalpidVpnService : LifecycleVpnService() { private var activeTunStatus by @@ -31,10 +50,21 @@ open class TalpidVpnService : LifecycleVpnService() { // Used by JNI val connectivityListener = ConnectivityListener() + private lateinit var defaultNetworkLinkProperties: + StateFlow + @CallSuper override fun onCreate() { super.onCreate() connectivityListener.register(this) + + val connectivityManager = getSystemService()!! + + defaultNetworkLinkProperties = + connectivityManager + .defaultCallbackFlow() + .filterIsInstance() + .stateIn(lifecycleScope, SharingStarted.Eagerly, null) } @CallSuper @@ -94,7 +124,7 @@ open class TalpidVpnService : LifecycleVpnService() { for (dnsServer in config.dnsServers) { try { addDnsServer(dnsServer) - } catch (exception: IllegalArgumentException) { + } catch (_: IllegalArgumentException) { invalidDnsServerAddresses.add(dnsServer) } } @@ -131,15 +161,19 @@ open class TalpidVpnService : LifecycleVpnService() { return CreateTunResult.TunnelDeviceError } + Logger.d("Vpn Interface Established") + if (vpnInterfaceFd == null) { Logger.e("VpnInterface returned null") return CreateTunResult.TunnelDeviceError } - val tunFd = vpnInterfaceFd.detachFd() - - waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 }) + // Wait for android OS to respond back to us that the routes are setup so we don't send + // traffic before the routes are set up. Otherwise we might send traffic through the wrong + // interface + runBlocking { waitForRoutesWithTimeout(config) } + val tunFd = vpnInterfaceFd.detachFd() if (invalidDnsServerAddresses.isNotEmpty()) { return CreateTunResult.InvalidDnsServers(invalidDnsServerAddresses, tunFd) } @@ -151,6 +185,30 @@ open class TalpidVpnService : LifecycleVpnService() { return protect(socket) } + @OptIn(ExperimentalCoroutinesApi::class) + private suspend fun waitForRoutesWithTimeout( + config: TunConfig, + timeout: Duration = ROUTES_SETUP_TIMEOUT, + ) { + val linkProperties = + withTimeoutOrNull(timeout = timeout) { + measureTimedValue { + defaultNetworkLinkProperties.filterNotNull().first { + it.linkProperties.matches(config) + } + } + .also { Logger.d("LinkProperties matching tunnel, took ${it.duration}") } + .value + } + if (linkProperties == null) { + Logger.w("Waiting for LinkProperties timed out") + } + } + + // return true if LinkProperties matches the TunConfig + private fun LinkProperties.matches(tunConfig: TunConfig): Boolean = + linkAddresses.all { it.address in tunConfig.addresses } + private fun InetAddress.prefixLength(): Int = when (this) { is Inet4Address -> IPV4_PREFIX_LENGTH @@ -158,12 +216,12 @@ open class TalpidVpnService : LifecycleVpnService() { else -> throw IllegalArgumentException("Invalid IP address (not IPv4 nor IPv6)") } - private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean) - companion object { private const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1" private const val IPV4_PREFIX_LENGTH = 32 private const val IPV6_PREFIX_LENGTH = 128 + + private val ROUTES_SETUP_TIMEOUT: Duration = 400.milliseconds } } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerExt.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerExt.kt new file mode 100644 index 000000000000..ace2ff42cacd --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerExt.kt @@ -0,0 +1,83 @@ +package net.mullvad.talpid.util + +import android.net.ConnectivityManager +import android.net.ConnectivityManager.NetworkCallback +import android.net.LinkProperties +import android.net.Network +import android.net.NetworkCapabilities +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.channels.trySendBlocking +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.callbackFlow + +sealed interface NetworkEvent { + data class OnAvailable(val network: Network) : NetworkEvent + + data object OnUnavailable : NetworkEvent + + data class OnLinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) : + NetworkEvent + + data class OnCapabilitiesChanged( + val network: Network, + val networkCapabilities: NetworkCapabilities, + ) : NetworkEvent + + data class OnBlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent + + data class OnLosing(val network: Network, val maxMsToLive: Int) : NetworkEvent + + data class OnLost(val network: Network) : NetworkEvent +} + +fun ConnectivityManager.defaultCallbackFlow(): Flow = + callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged( + network: Network, + linkProperties: LinkProperties, + ) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.OnLinkPropertiesChanged(network, linkProperties)) + } + + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.OnAvailable(network)) + } + + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking( + NetworkEvent.OnCapabilitiesChanged(network, networkCapabilities) + ) + } + + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.OnBlockedStatusChanged(network, blocked)) + } + + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.OnLosing(network, maxMsToLive)) + } + + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.OnLost(network)) + } + + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.OnUnavailable) + } + } + registerDefaultNetworkCallback(callback) + + awaitClose { unregisterNetworkCallback(callback) } + } diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index 755cfce62341..bfd8e830f7e2 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -3,7 +3,6 @@ mod api; mod classes; mod problem_report; -mod talpid_vpn_service; use jnix::{ jni::{ diff --git a/mullvad-jni/src/talpid_vpn_service.rs b/mullvad-jni/src/talpid_vpn_service.rs deleted file mode 100644 index ea6928538a86..000000000000 --- a/mullvad-jni/src/talpid_vpn_service.rs +++ /dev/null @@ -1,181 +0,0 @@ -use ipnetwork::IpNetwork; -use jnix::jni::{ - objects::JObject, - sys::{jboolean, jint, JNI_FALSE}, - JNIEnv, -}; -use nix::sys::{ - select::{pselect, FdSet}, - time::{TimeSpec, TimeValLike}, -}; -use rand::{thread_rng, Rng}; -use std::{ - io, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, - os::unix::io::RawFd, - time::{Duration, Instant}, -}; -use talpid_types::ErrorExt; - -#[derive(Debug, thiserror::Error)] -enum Error { - #[error("Failed to verify the tunnel device")] - VerifyTunDevice(#[from] SendRandomDataError), - - #[error("Failed to select() on tunnel device")] - Select(#[from] nix::Error), - - #[error("Timed out while waiting for tunnel device to receive data")] - TunnelDeviceTimeout, -} - -#[no_mangle] -#[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_talpid_TalpidVpnService_waitForTunnelUp( - _: JNIEnv<'_>, - _this: JObject<'_>, - tunFd: jint, - isIpv6Enabled: jboolean, -) { - let tun_fd = tunFd as RawFd; - let is_ipv6_enabled = isIpv6Enabled != JNI_FALSE; - - if let Err(error) = wait_for_tunnel_up(tun_fd, is_ipv6_enabled) { - log::error!( - "{}", - error.display_chain_with_msg("Failed to wait for tunnel device to be usable") - ); - } -} - -fn wait_for_tunnel_up(tun_fd: RawFd, is_ipv6_enabled: bool) -> Result<(), Error> { - let mut fd_set = FdSet::new(); - fd_set.insert(tun_fd); - let timeout = TimeSpec::microseconds(300); - const TIMEOUT: Duration = Duration::from_secs(60); - let start = Instant::now(); - while start.elapsed() < TIMEOUT { - // if tunnel device is ready to be read from, traffic is being routed through it - if pselect(None, Some(&mut fd_set), None, None, Some(&timeout), None)? > 0 { - return Ok(()); - } - // have to add tun_fd back into the bitset - fd_set.insert(tun_fd); - try_sending_random_udp(is_ipv6_enabled)?; - } - - Err(Error::TunnelDeviceTimeout) -} - -#[derive(Debug, thiserror::Error)] -enum SendRandomDataError { - #[error("Failed to bind an UDP socket")] - BindUdpSocket(#[source] io::Error), - - #[error("Failed to send random data through UDP socket")] - SendToUdpSocket(#[source] io::Error), -} - -fn try_sending_random_udp(is_ipv6_enabled: bool) -> Result<(), SendRandomDataError> { - let mut tried_ipv6 = false; - const TIMEOUT: Duration = Duration::from_millis(300); - let start = Instant::now(); - - while start.elapsed() < TIMEOUT { - // TODO: if we are to allow LAN on Android by changing the routes that are stuffed in - // TunConfig, then this should be revisited to be fair between IPv4 and IPv6 - let should_generate_ipv4 = !is_ipv6_enabled || tried_ipv6 || thread_rng().gen(); - let (bound_addr, random_public_addr) = random_socket_addrs(should_generate_ipv4); - - tried_ipv6 |= random_public_addr.ip().is_ipv6(); - - let socket = UdpSocket::bind(bound_addr).map_err(SendRandomDataError::BindUdpSocket)?; - match socket.send_to(&random_data(), random_public_addr) { - Ok(_) => return Ok(()), - // Always retry on IPv6 errors - Err(_) if random_public_addr.ip().is_ipv6() => continue, - Err(_err) if matches!(_err.raw_os_error(), Some(22) | Some(101)) => { - // Error code 101 - specified network is unreachable - // Error code 22 - specified address is not usable - continue; - } - Err(err) => return Err(SendRandomDataError::SendToUdpSocket(err)), - } - } - Ok(()) -} - -fn random_data() -> Vec { - let mut buf = vec![0u8; thread_rng().gen_range(17..214)]; - thread_rng().fill(buf.as_mut_slice()); - buf -} - -/// Returns a random local and public destination socket address. -/// If `ipv4` is true, then IPv4 addresses will be returned. Otherwise, IPv6 addresses will be -/// returned. -fn random_socket_addrs(ipv4: bool) -> (SocketAddr, SocketAddr) { - loop { - let rand_port = thread_rng().gen(); - let (local_addr, rand_dest_addr) = if ipv4 { - let mut ipv4_bytes = [0u8; 4]; - thread_rng().fill(&mut ipv4_bytes); - ( - SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), - SocketAddr::new(IpAddr::from(ipv4_bytes), rand_port), - ) - } else { - let mut ipv6_bytes = [0u8; 16]; - thread_rng().fill(&mut ipv6_bytes); - ( - SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - SocketAddr::new(IpAddr::from(ipv6_bytes), rand_port), - ) - }; - - // TODO: once https://github.com/rust-lang/rust/issues/27709 is resolved, please use - // `is_global()` to check if a new address should be attempted. - if !is_public_ip(rand_dest_addr.ip()) { - continue; - } - - return (local_addr, rand_dest_addr); - } -} - -fn is_public_ip(addr: IpAddr) -> bool { - match addr { - IpAddr::V4(ipv4) => { - // 0.x.x.x is not a publicly routable address - if ipv4.octets()[0] == 0u8 { - return false; - } - } - IpAddr::V6(ipv6) => { - if ipv6.segments()[0] == 0u16 { - return false; - } - } - } - // A non-exhaustive list of non-public subnets - let publicly_unroutable_subnets: Vec = vec![ - // IPv4 local networks - "10.0.0.0/8".parse().unwrap(), - "172.16.0.0/12".parse().unwrap(), - "192.168.0.0/16".parse().unwrap(), - // IPv4 non-forwardable network - "169.254.0.0/16".parse().unwrap(), - "192.0.0.0/8".parse().unwrap(), - // Documentation networks - "192.0.2.0/24".parse().unwrap(), - "198.51.100.0/24".parse().unwrap(), - "203.0.113.0/24".parse().unwrap(), - // IPv6 publicly unroutable networks - "fc00::/7".parse().unwrap(), - "fe80::/10".parse().unwrap(), - ]; - - !publicly_unroutable_subnets - .iter() - .any(|net| net.contains(addr)) -}