Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace old waitForTunnelUp function #7155

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions android/lib/talpid/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies {
implementation(projects.lib.model)

implementation(libs.androidx.lifecycle.service)
implementation(libs.androidx.ktx)
implementation(libs.kermit)
implementation(libs.kotlin.stdlib)
implementation(libs.kotlinx.coroutines.android)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import android.net.ConnectivityManager.NetworkCallback
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import co.touchlab.kermit.Logger
import kotlin.properties.Delegates.observable

class ConnectivityListener {
Expand All @@ -14,6 +15,7 @@ class ConnectivityListener {
private val callback =
object : NetworkCallback() {
override fun onAvailable(network: Network) {
Logger.d("Network $network")
availableNetworks.add(network)
isConnected = true
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
package net.mullvad.talpid

import android.net.ConnectivityManager
import android.net.LinkProperties
import android.os.ParcelFileDescriptor
import android.system.Os.socket
import androidx.annotation.CallSuper
import androidx.core.content.getSystemService
import androidx.lifecycle.lifecycleScope
import co.touchlab.kermit.Logger
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
import kotlin.properties.Delegates.observable
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.measureTimedValue
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull
import net.mullvad.talpid.model.CreateTunResult
import net.mullvad.talpid.model.TunConfig
import net.mullvad.talpid.util.NetworkEvent
import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported
import net.mullvad.talpid.util.defaultCallbackFlow

open class TalpidVpnService : LifecycleVpnService() {
private var activeTunStatus by
Expand All @@ -31,10 +50,21 @@ open class TalpidVpnService : LifecycleVpnService() {
// Used by JNI
val connectivityListener = ConnectivityListener()

private lateinit var defaultNetworkLinkProperties:
StateFlow<NetworkEvent.OnLinkPropertiesChanged?>

@CallSuper
override fun onCreate() {
super.onCreate()
connectivityListener.register(this)

val connectivityManager = getSystemService<ConnectivityManager>()!!

defaultNetworkLinkProperties =
connectivityManager
.defaultCallbackFlow()
.filterIsInstance<NetworkEvent.OnLinkPropertiesChanged>()
.stateIn(lifecycleScope, SharingStarted.Eagerly, null)
}

@CallSuper
Expand Down Expand Up @@ -94,7 +124,7 @@ open class TalpidVpnService : LifecycleVpnService() {
for (dnsServer in config.dnsServers) {
try {
addDnsServer(dnsServer)
} catch (exception: IllegalArgumentException) {
} catch (_: IllegalArgumentException) {
invalidDnsServerAddresses.add(dnsServer)
}
}
Expand Down Expand Up @@ -131,15 +161,19 @@ open class TalpidVpnService : LifecycleVpnService() {
return CreateTunResult.TunnelDeviceError
}

Logger.d("Vpn Interface Established")

if (vpnInterfaceFd == null) {
Logger.e("VpnInterface returned null")
return CreateTunResult.TunnelDeviceError
}

val tunFd = vpnInterfaceFd.detachFd()

waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 })
// Wait for android OS to respond back to us that the routes are setup so we don't send
// traffic before the routes are set up. Otherwise we might send traffic through the wrong
// interface
runBlocking { waitForRoutesWithTimeout(config) }

val tunFd = vpnInterfaceFd.detachFd()
if (invalidDnsServerAddresses.isNotEmpty()) {
return CreateTunResult.InvalidDnsServers(invalidDnsServerAddresses, tunFd)
}
Expand All @@ -151,19 +185,43 @@ open class TalpidVpnService : LifecycleVpnService() {
return protect(socket)
}

@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun waitForRoutesWithTimeout(
config: TunConfig,
timeout: Duration = ROUTES_SETUP_TIMEOUT,
) {
val linkProperties =
withTimeoutOrNull(timeout = timeout) {
measureTimedValue {
defaultNetworkLinkProperties.filterNotNull().first {
it.linkProperties.matches(config)
}
}
.also { Logger.d("LinkProperties matching tunnel, took ${it.duration}") }
.value
}
if (linkProperties == null) {
Logger.w("Waiting for LinkProperties timed out")
}
}

// return true if LinkProperties matches the TunConfig
private fun LinkProperties.matches(tunConfig: TunConfig): Boolean =
linkAddresses.all { it.address in tunConfig.addresses }

private fun InetAddress.prefixLength(): Int =
when (this) {
is Inet4Address -> IPV4_PREFIX_LENGTH
is Inet6Address -> IPV6_PREFIX_LENGTH
else -> throw IllegalArgumentException("Invalid IP address (not IPv4 nor IPv6)")
}

private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean)

companion object {
private const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1"

private const val IPV4_PREFIX_LENGTH = 32
private const val IPV6_PREFIX_LENGTH = 128

private val ROUTES_SETUP_TIMEOUT: Duration = 400.milliseconds
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package net.mullvad.talpid.util

import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.channels.trySendBlocking
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow

sealed interface NetworkEvent {
data class OnAvailable(val network: Network) : NetworkEvent

data object OnUnavailable : NetworkEvent

data class OnLinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) :
NetworkEvent

data class OnCapabilitiesChanged(
val network: Network,
val networkCapabilities: NetworkCapabilities,
) : NetworkEvent

data class OnBlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent

data class OnLosing(val network: Network, val maxMsToLive: Int) : NetworkEvent

data class OnLost(val network: Network) : NetworkEvent
}

fun ConnectivityManager.defaultCallbackFlow(): Flow<NetworkEvent> =
callbackFlow<NetworkEvent> {
val callback =
object : NetworkCallback() {
override fun onLinkPropertiesChanged(
network: Network,
linkProperties: LinkProperties,
) {
super.onLinkPropertiesChanged(network, linkProperties)
trySendBlocking(NetworkEvent.OnLinkPropertiesChanged(network, linkProperties))
}

override fun onAvailable(network: Network) {
super.onAvailable(network)
trySendBlocking(NetworkEvent.OnAvailable(network))
}

override fun onCapabilitiesChanged(
network: Network,
networkCapabilities: NetworkCapabilities,
) {
super.onCapabilitiesChanged(network, networkCapabilities)
trySendBlocking(
NetworkEvent.OnCapabilitiesChanged(network, networkCapabilities)
)
}

override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
super.onBlockedStatusChanged(network, blocked)
trySendBlocking(NetworkEvent.OnBlockedStatusChanged(network, blocked))
}

override fun onLosing(network: Network, maxMsToLive: Int) {
super.onLosing(network, maxMsToLive)
trySendBlocking(NetworkEvent.OnLosing(network, maxMsToLive))
}

override fun onLost(network: Network) {
super.onLost(network)
trySendBlocking(NetworkEvent.OnLost(network))
}

override fun onUnavailable() {
super.onUnavailable()
trySendBlocking(NetworkEvent.OnUnavailable)
}
}
registerDefaultNetworkCallback(callback)

awaitClose { unregisterNetworkCallback(callback) }
}
1 change: 0 additions & 1 deletion mullvad-jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod api;
mod classes;
mod problem_report;
mod talpid_vpn_service;

use jnix::{
jni::{
Expand Down
Loading
Loading