Skip to content

Commit

Permalink
Replace old waitForTunnelUp function
Browse files Browse the repository at this point in the history
After invoking VpnService.establish() we will get a tunnel file
descriptor that corresponds to the interface that was created. However,
this has no guarantee of the routing table beeing up to date, and we
might thus send traffic outside the tunnel. Previously this was done
through looking at the tunFd to see that traffic is sent to verify that
the routing table has changed. If no traffic is seen some traffic is
induced to a random IP address to ensure traffic can be seen. This new
implementation is slower but won't risk sending UDP traffic to a random
public address at the internet.
  • Loading branch information
Rawa committed Nov 12, 2024
1 parent 6de1244 commit 9c0127f
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 188 deletions.
1 change: 1 addition & 0 deletions android/lib/talpid/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies {
implementation(projects.lib.model)

implementation(libs.androidx.lifecycle.service)
implementation(libs.androidx.ktx)
implementation(libs.kermit)
implementation(libs.kotlin.stdlib)
implementation(libs.kotlinx.coroutines.android)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import android.net.ConnectivityManager.NetworkCallback
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import co.touchlab.kermit.Logger
import kotlin.properties.Delegates.observable

class ConnectivityListener {
Expand All @@ -14,6 +15,7 @@ class ConnectivityListener {
private val callback =
object : NetworkCallback() {
override fun onAvailable(network: Network) {
Logger.d("Network $network")
availableNetworks.add(network)
isConnected = true
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
package net.mullvad.talpid

import android.net.ConnectivityManager
import android.net.LinkProperties
import android.os.ParcelFileDescriptor
import android.system.Os.socket
import androidx.annotation.CallSuper
import androidx.core.content.getSystemService
import androidx.lifecycle.lifecycleScope
import co.touchlab.kermit.Logger
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
import kotlin.properties.Delegates.observable
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.measureTimedValue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull
import net.mullvad.talpid.model.CreateTunResult
import net.mullvad.talpid.model.TunConfig
import net.mullvad.talpid.util.NetworkEvent
import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported
import net.mullvad.talpid.util.defaultCallbackFlow

open class TalpidVpnService : LifecycleVpnService() {
private var activeTunStatus by
Expand All @@ -31,10 +50,21 @@ open class TalpidVpnService : LifecycleVpnService() {
// Used by JNI
val connectivityListener = ConnectivityListener()

private lateinit var defaultNetworkLinkProperties:
StateFlow<NetworkEvent.OnLinkPropertiesChanged?>

@CallSuper
override fun onCreate() {
super.onCreate()
connectivityListener.register(this)

val connectivityManager = getSystemService<ConnectivityManager>()!!

defaultNetworkLinkProperties =
connectivityManager
.defaultCallbackFlow()
.filterIsInstance<NetworkEvent.OnLinkPropertiesChanged>()
.stateIn(lifecycleScope, SharingStarted.Eagerly, null)
}

@CallSuper
Expand Down Expand Up @@ -94,7 +124,7 @@ open class TalpidVpnService : LifecycleVpnService() {
for (dnsServer in config.dnsServers) {
try {
addDnsServer(dnsServer)
} catch (exception: IllegalArgumentException) {
} catch (_: IllegalArgumentException) {
invalidDnsServerAddresses.add(dnsServer)
}
}
Expand Down Expand Up @@ -131,15 +161,19 @@ open class TalpidVpnService : LifecycleVpnService() {
return CreateTunResult.TunnelDeviceError
}

Logger.d("Vpn Interface Established")

if (vpnInterfaceFd == null) {
Logger.e("VpnInterface returned null")
return CreateTunResult.TunnelDeviceError
}

val tunFd = vpnInterfaceFd.detachFd()

waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 })
// Wait for android OS to respond back to us that the routes are setup so we don't send
// traffic before the routes are set up. Otherwise we might send traffic through the wrong
// interface
runBlocking { waitForRoutesWithTimeout(config) }

val tunFd = vpnInterfaceFd.detachFd()
if (invalidDnsServerAddresses.isNotEmpty()) {
return CreateTunResult.InvalidDnsServers(invalidDnsServerAddresses, tunFd)
}
Expand All @@ -151,19 +185,43 @@ open class TalpidVpnService : LifecycleVpnService() {
return protect(socket)
}

@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun waitForRoutesWithTimeout(
config: TunConfig,
timeout: Duration = ROUTES_SETUP_TIMEOUT,
) {
val linkProperties =
withTimeoutOrNull(timeout = timeout) {
measureTimedValue {
defaultNetworkLinkProperties.filterNotNull().first {
it.linkProperties.matches(config)
}
}
.also { Logger.d("LinkProperties matching tunnel, took ${it.duration}") }
.value
}
if (linkProperties == null) {
Logger.w("Waiting for LinkProperties timed out")
}
}

// return true if LinkProperties matches the TunConfig
private fun LinkProperties.matches(tunConfig: TunConfig): Boolean =
linkAddresses.all { it.address in tunConfig.addresses }

private fun InetAddress.prefixLength(): Int =
when (this) {
is Inet4Address -> IPV4_PREFIX_LENGTH
is Inet6Address -> IPV6_PREFIX_LENGTH
else -> throw IllegalArgumentException("Invalid IP address (not IPv4 nor IPv6)")
}

private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean)

companion object {
private const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1"

private const val IPV4_PREFIX_LENGTH = 32
private const val IPV6_PREFIX_LENGTH = 128

private val ROUTES_SETUP_TIMEOUT: Duration = 400.milliseconds
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package net.mullvad.talpid.util

import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.channels.trySendBlocking
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow

sealed interface NetworkEvent {
data class OnAvailable(val network: Network) : NetworkEvent

data object OnUnavailable : NetworkEvent

data class OnLinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) :
NetworkEvent

data class OnCapabilitiesChanged(
val network: Network,
val networkCapabilities: NetworkCapabilities,
) : NetworkEvent

data class OnBlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent

data class OnLosing(val network: Network, val maxMsToLive: Int) : NetworkEvent

data class OnLost(val network: Network) : NetworkEvent
}

fun ConnectivityManager.defaultCallbackFlow(): Flow<NetworkEvent> =
callbackFlow<NetworkEvent> {
val callback =
object : NetworkCallback() {
override fun onLinkPropertiesChanged(
network: Network,
linkProperties: LinkProperties,
) {
super.onLinkPropertiesChanged(network, linkProperties)
trySendBlocking(NetworkEvent.OnLinkPropertiesChanged(network, linkProperties))
}

override fun onAvailable(network: Network) {
super.onAvailable(network)
trySendBlocking(NetworkEvent.OnAvailable(network))
}

override fun onCapabilitiesChanged(
network: Network,
networkCapabilities: NetworkCapabilities,
) {
super.onCapabilitiesChanged(network, networkCapabilities)
trySendBlocking(
NetworkEvent.OnCapabilitiesChanged(network, networkCapabilities)
)
}

override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
super.onBlockedStatusChanged(network, blocked)
trySendBlocking(NetworkEvent.OnBlockedStatusChanged(network, blocked))
}

override fun onLosing(network: Network, maxMsToLive: Int) {
super.onLosing(network, maxMsToLive)
trySendBlocking(NetworkEvent.OnLosing(network, maxMsToLive))
}

override fun onLost(network: Network) {
super.onLost(network)
trySendBlocking(NetworkEvent.OnLost(network))
}

override fun onUnavailable() {
super.onUnavailable()
trySendBlocking(NetworkEvent.OnUnavailable)
}
}
registerDefaultNetworkCallback(callback)

awaitClose { unregisterNetworkCallback(callback) }
}
1 change: 0 additions & 1 deletion mullvad-jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod api;
mod classes;
mod problem_report;
mod talpid_vpn_service;

use jnix::{
jni::{
Expand Down
Loading

0 comments on commit 9c0127f

Please sign in to comment.