Skip to content

Commit

Permalink
Set tunnel protocol explicitly to Wireguard
Browse files Browse the repository at this point in the history
  • Loading branch information
Pururun committed Dec 8, 2023
1 parent 0313af5 commit b85daaa
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package net.mullvad.mullvadvpn.model

import android.os.Parcelable
import kotlinx.parcelize.Parcelize
import net.mullvad.talpid.net.TunnelType

@Parcelize
data class RelayConstraints(
val location: Constraint<LocationConstraint>,
val providers: Constraint<Providers>,
val ownership: Constraint<Ownership>,
val tunnelProtocol: Constraint<TunnelType>,
val wireguardConstraints: WireguardConstraints,
) : Parcelable
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package net.mullvad.talpid.net

import android.os.Parcelable
import kotlinx.parcelize.Parcelize

@Parcelize
enum class TunnelType : Parcelable {
OpenVpn,
Wireguard
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import net.mullvad.mullvadvpn.model.RelayList
import net.mullvad.mullvadvpn.model.RelaySettings
import net.mullvad.mullvadvpn.model.WireguardConstraints
import net.mullvad.mullvadvpn.service.MullvadDaemon
import net.mullvad.talpid.net.TunnelType

class RelayListListener(
endpoint: ServiceEndpoint,
Expand Down Expand Up @@ -51,7 +52,7 @@ class RelayListListener(
LocationConstraint.Location(request.relayLocation)
)
)
daemon.await().setRelaySettings(RelaySettings.Normal(update))
updateRelayConstraints(update)
}
}

Expand All @@ -62,7 +63,7 @@ class RelayListListener(
val update =
getCurrentRelayConstraints()
.copy(wireguardConstraints = request.wireguardConstraints)
daemon.await().setRelaySettings(RelaySettings.Normal(update))
updateRelayConstraints(update)
}
}

Expand All @@ -79,7 +80,7 @@ class RelayListListener(
val update =
getCurrentRelayConstraints()
.copy(ownership = request.ownership, providers = request.providers)
daemon.await().setRelaySettings(RelaySettings.Normal(update))
updateRelayConstraints(update)
}
}
}
Expand All @@ -89,6 +90,17 @@ class RelayListListener(
scope.cancel()
}

private suspend fun updateRelayConstraints(update: RelayConstraints) {
daemon
.await()
.setRelaySettings(
RelaySettings.Normal(
// Force Wireguard protocol
update.copy(tunnelProtocol = Constraint.Only(TunnelType.Wireguard))
)
)
}

private fun setUpListener(daemon: MullvadDaemon) {
daemon.onRelayListChange = { relayLocations -> relayList = relayLocations }
}
Expand All @@ -109,6 +121,8 @@ class RelayListListener(
location = Constraint.Any(),
providers = Constraint.Any(),
ownership = Constraint.Any(),
// Force Wireguard protocol
tunnelProtocol = Constraint.Only(TunnelType.Wireguard),
wireguardConstraints = WireguardConstraints(Constraint.Any())
)
}
Expand Down
1 change: 1 addition & 0 deletions mullvad-jni/src/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub const CLASSES: &[&str] = &[
"net/mullvad/talpid/net/TunnelEndpoint",
"net/mullvad/talpid/net/ObfuscationEndpoint",
"net/mullvad/talpid/net/ObfuscationType",
"net/mullvad/talpid/net/TunnelType",
"net/mullvad/talpid/tun_provider/InetNetwork",
"net/mullvad/talpid/tun_provider/TunConfig",
"net/mullvad/talpid/tunnel/ActionAfterDisconnect",
Expand Down
18 changes: 16 additions & 2 deletions mullvad-types/src/relay_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ pub struct RelayConstraints {
pub location: Constraint<LocationConstraint>,
pub providers: Constraint<Providers>,
pub ownership: Constraint<Ownership>,
#[cfg_attr(target_os = "android", jnix(skip))]
pub tunnel_protocol: Constraint<TunnelType>,
pub wireguard_constraints: WireguardConstraints,
#[cfg_attr(target_os = "android", jnix(skip))]
Expand Down Expand Up @@ -467,10 +466,24 @@ where

let ownership: Constraint<Ownership> = Constraint::from_java(env, object_ownership);

let object_wireguard_constraints = env
let object_tunnel_protocol = env
.call_method(
object,
"component4",
"()Lnet/mullvad/mullvadvpn/model/Constraint;",
&[],
)
.expect("missing RelayConstraints.tunnel_protocol")
.l()
.expect("RelayConstraints.tunnel_protocol did not return an object");

let tunnel_protocol: Constraint<TunnelType> =
Constraint::from_java(env, object_tunnel_protocol);

let object_wireguard_constraints = env
.call_method(
object,
"component5",
"()Lnet/mullvad/mullvadvpn/model/WireguardConstraints;",
&[],
)
Expand All @@ -485,6 +498,7 @@ where
location,
providers,
ownership,
tunnel_protocol,
wireguard_constraints,
..Default::default()
}
Expand Down
4 changes: 3 additions & 1 deletion talpid-types/src/net/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(target_os = "android")]
use jnix::IntoJava;
use jnix::{FromJava, IntoJava};
use obfuscation::ObfuscatorConfig;
use serde::{Deserialize, Serialize};
#[cfg(windows)]
Expand Down Expand Up @@ -109,6 +109,8 @@ impl From<openvpn::TunnelParameters> for TunnelParameters {

/// The tunnel protocol used by a [`TunnelEndpoint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg_attr(target_os = "android", derive(IntoJava, FromJava))]
#[cfg_attr(target_os = "android", jnix(package = "net.mullvad.talpid.net"))]
#[serde(rename = "tunnel_type")]
pub enum TunnelType {
#[serde(rename = "openvpn")]
Expand Down

0 comments on commit b85daaa

Please sign in to comment.