diff --git a/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelayConstraints.kt b/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelayConstraints.kt index 031b09bacecc..82c8b26648c8 100644 --- a/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelayConstraints.kt +++ b/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelayConstraints.kt @@ -2,11 +2,13 @@ package net.mullvad.mullvadvpn.model import android.os.Parcelable import kotlinx.parcelize.Parcelize +import net.mullvad.talpid.net.TunnelType @Parcelize data class RelayConstraints( val location: Constraint, val providers: Constraint, val ownership: Constraint, + val tunnelProtocol: Constraint, val wireguardConstraints: WireguardConstraints, ) : Parcelable diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/net/TunnelType.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/net/TunnelType.kt new file mode 100644 index 000000000000..1cce54e09d30 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/net/TunnelType.kt @@ -0,0 +1,10 @@ +package net.mullvad.talpid.net + +import android.os.Parcelable +import kotlinx.parcelize.Parcelize + +@Parcelize +enum class TunnelType : Parcelable { + OpenVpn, + Wireguard +} diff --git a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt index 2f18c090640a..7a5a604d2278 100644 --- a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt +++ b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt @@ -17,6 +17,7 @@ import net.mullvad.mullvadvpn.model.RelayList import net.mullvad.mullvadvpn.model.RelaySettings import net.mullvad.mullvadvpn.model.WireguardConstraints import net.mullvad.mullvadvpn.service.MullvadDaemon +import net.mullvad.talpid.net.TunnelType class RelayListListener( endpoint: ServiceEndpoint, @@ -51,7 +52,7 @@ class RelayListListener( LocationConstraint.Location(request.relayLocation) ) ) - daemon.await().setRelaySettings(RelaySettings.Normal(update)) + updateRelayConstraints(update) } } @@ -62,7 +63,7 @@ class RelayListListener( val update = getCurrentRelayConstraints() .copy(wireguardConstraints = request.wireguardConstraints) - daemon.await().setRelaySettings(RelaySettings.Normal(update)) + updateRelayConstraints(update) } } @@ -79,7 +80,7 @@ class RelayListListener( val update = getCurrentRelayConstraints() .copy(ownership = request.ownership, providers = request.providers) - daemon.await().setRelaySettings(RelaySettings.Normal(update)) + updateRelayConstraints(update) } } } @@ -89,6 +90,17 @@ class RelayListListener( scope.cancel() } + private suspend fun updateRelayConstraints(update: RelayConstraints) { + daemon + .await() + .setRelaySettings( + RelaySettings.Normal( + // Force Wireguard protocol + update.copy(tunnelProtocol = Constraint.Only(TunnelType.Wireguard)) + ) + ) + } + private fun setUpListener(daemon: MullvadDaemon) { daemon.onRelayListChange = { relayLocations -> relayList = relayLocations } } @@ -109,6 +121,8 @@ class RelayListListener( location = Constraint.Any(), providers = Constraint.Any(), ownership = Constraint.Any(), + // Force Wireguard protocol + tunnelProtocol = Constraint.Only(TunnelType.Wireguard), wireguardConstraints = WireguardConstraints(Constraint.Any()) ) } diff --git a/mullvad-jni/src/classes.rs b/mullvad-jni/src/classes.rs index 454e0e6c6281..59d3ff1e0450 100644 --- a/mullvad-jni/src/classes.rs +++ b/mullvad-jni/src/classes.rs @@ -76,6 +76,7 @@ pub const CLASSES: &[&str] = &[ "net/mullvad/talpid/net/TunnelEndpoint", "net/mullvad/talpid/net/ObfuscationEndpoint", "net/mullvad/talpid/net/ObfuscationType", + "net/mullvad/talpid/net/TunnelType", "net/mullvad/talpid/tun_provider/InetNetwork", "net/mullvad/talpid/tun_provider/TunConfig", "net/mullvad/talpid/tunnel/ActionAfterDisconnect", diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index 6645fb493d21..827a6cd4ee37 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -380,7 +380,6 @@ pub struct RelayConstraints { pub location: Constraint, pub providers: Constraint, pub ownership: Constraint, - #[cfg_attr(target_os = "android", jnix(skip))] pub tunnel_protocol: Constraint, pub wireguard_constraints: WireguardConstraints, #[cfg_attr(target_os = "android", jnix(skip))] @@ -467,10 +466,24 @@ where let ownership: Constraint = Constraint::from_java(env, object_ownership); - let object_wireguard_constraints = env + let object_tunnel_protocol = env .call_method( object, "component4", + "()Lnet/mullvad/mullvadvpn/model/Constraint;", + &[], + ) + .expect("missing RelayConstraints.tunnel_protocol") + .l() + .expect("RelayConstraints.tunnel_protocol did not return an object"); + + let tunnel_protocol: Constraint = + Constraint::from_java(env, object_tunnel_protocol); + + let object_wireguard_constraints = env + .call_method( + object, + "component5", "()Lnet/mullvad/mullvadvpn/model/WireguardConstraints;", &[], ) @@ -485,6 +498,7 @@ where location, providers, ownership, + tunnel_protocol, wireguard_constraints, ..Default::default() } diff --git a/talpid-types/src/net/mod.rs b/talpid-types/src/net/mod.rs index cd8d5987a912..0dc4600c1198 100644 --- a/talpid-types/src/net/mod.rs +++ b/talpid-types/src/net/mod.rs @@ -1,5 +1,5 @@ #[cfg(target_os = "android")] -use jnix::IntoJava; +use jnix::{FromJava, IntoJava}; use obfuscation::ObfuscatorConfig; use serde::{Deserialize, Serialize}; #[cfg(windows)] @@ -109,6 +109,8 @@ impl From for TunnelParameters { /// The tunnel protocol used by a [`TunnelEndpoint`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[cfg_attr(target_os = "android", derive(IntoJava, FromJava))] +#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.talpid.net"))] #[serde(rename = "tunnel_type")] pub enum TunnelType { #[serde(rename = "openvpn")]