diff --git a/Cargo.lock b/Cargo.lock index 80776af1992..36ee61dd332 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1676,6 +1676,9 @@ checksum = "d93bd0ebf93d61d6332d3c09a96e97975968a44e19a64c947bde06e6baff383f" dependencies = [ "futures-core", "readlock", + "readlock-tokio", + "tokio", + "tokio-util", "tracing", ] @@ -4540,6 +4543,15 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "072cfe5b1d2dcd38d20e18f85e9c9978b6cc08f0b373e9f1fff1541335622974" +[[package]] +name = "readlock-tokio" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "867fac64d07214a87e5cf4e88b4ce855844a1cea243534392377d1ac2c911653" +dependencies = [ + "tokio", +] + [[package]] name = "redox_syscall" version = "0.5.3" diff --git a/bindings/matrix-sdk-ffi/src/room.rs b/bindings/matrix-sdk-ffi/src/room.rs index fd0f6aef3fa..c1233b51151 100644 --- a/bindings/matrix-sdk-ffi/src/room.rs +++ b/bindings/matrix-sdk-ffi/src/room.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, pin::pin, sync::Arc}; use anyhow::{Context, Result}; -use futures_util::StreamExt; +use futures_util::{pin_mut, StreamExt}; use matrix_sdk::{ crypto::LocalTrust, event_cache::paginator::PaginatorError, @@ -911,6 +911,108 @@ impl Room { room_event_cache.clear().await?; Ok(()) } + + /// Subscribes to requests to join this room (knock member events), using a + /// `listener` to be notified of the changes. + /// + /// The current requests to join the room will be emitted immediately + /// when subscribing, along with a [`TaskHandle`] to cancel the + /// subscription. + pub async fn subscribe_to_knock_requests( + self: Arc, + listener: Box, + ) -> Result, ClientError> { + let stream = self.inner.subscribe_to_knock_requests().await?; + + let handle = Arc::new(TaskHandle::new(RUNTIME.spawn(async move { + pin_mut!(stream); + while let Some(requests) = stream.next().await { + listener.call(requests.into_iter().map(Into::into).collect()); + } + }))); + + Ok(handle) + } +} + +impl From for KnockRequest { + fn from(request: matrix_sdk::room::knock_requests::KnockRequest) -> Self { + Self { + event_id: request.event_id.to_string(), + user_id: request.member_info.user_id.to_string(), + room_id: request.room_id().to_string(), + display_name: request.member_info.display_name.clone(), + avatar_url: request.member_info.avatar_url.as_ref().map(|url| url.to_string()), + reason: request.member_info.reason.clone(), + timestamp: request.timestamp.map(|ts| ts.into()), + is_seen: request.is_seen, + actions: Arc::new(KnockRequestActions { inner: request }), + } + } +} + +/// A listener for receiving new requests to a join a room. +#[matrix_sdk_ffi_macros::export(callback_interface)] +pub trait KnockRequestsListener: Send + Sync { + fn call(&self, join_requests: Vec); +} + +/// An FFI representation of a request to join a room. +#[derive(Debug, Clone, uniffi::Record)] +pub struct KnockRequest { + /// The event id of the event that contains the `knock` membership change. + pub event_id: String, + /// The user id of the user who's requesting to join the room. + pub user_id: String, + /// The room id of the room whose access was requested. + pub room_id: String, + /// The optional display name of the user who's requesting to join the room. + pub display_name: Option, + /// The optional avatar url of the user who's requesting to join the room. + pub avatar_url: Option, + /// An optional reason why the user wants join the room. + pub reason: Option, + /// The timestamp when this request was created. + pub timestamp: Option, + /// Whether the knock request has been marked as `seen` so it can be + /// filtered by the client. + pub is_seen: bool, + /// A set of actions to perform for this knock request. + pub actions: Arc, +} + +/// A set of actions to perform for a knock request. +#[derive(Debug, Clone, uniffi::Object)] +pub struct KnockRequestActions { + inner: matrix_sdk::room::knock_requests::KnockRequest, +} + +#[matrix_sdk_ffi_macros::export] +impl KnockRequestActions { + /// Accepts the knock request by inviting the user to the room. + pub async fn accept(&self) -> Result<(), ClientError> { + self.inner.accept().await.map_err(Into::into) + } + + /// Declines the knock request by kicking the user from the room with an + /// optional reason. + pub async fn decline(&self, reason: Option) -> Result<(), ClientError> { + self.inner.decline(reason.as_deref()).await.map_err(Into::into) + } + + /// Declines the knock request by banning the user from the room with an + /// optional reason. + pub async fn decline_and_ban(&self, reason: Option) -> Result<(), ClientError> { + self.inner.decline_and_ban(reason.as_deref()).await.map_err(Into::into) + } + + /// Marks the knock request as 'seen'. + /// + /// **IMPORTANT**: this won't update the current reference to this request, + /// a new one with the updated value should be emitted instead. + pub async fn mark_as_seen(&self) -> Result<(), ClientError> { + self.inner.mark_as_seen().await.map_err(Into::into) + } } /// Generates a `matrix.to` permalink to the given room alias. diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index 656beca6ef8..58a9c50e64c 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -51,7 +51,7 @@ assert_matches2 = { workspace = true, optional = true } async-trait = { workspace = true } bitflags = { version = "2.6.0", features = ["serde"] } decancer = "3.2.8" -eyeball = { workspace = true } +eyeball = { workspace = true, features = ["async-lock"] } eyeball-im = { workspace = true } futures-util = { workspace = true } growable-bloom-filter = { workspace = true } diff --git a/crates/matrix-sdk-base/src/deserialized_responses.rs b/crates/matrix-sdk-base/src/deserialized_responses.rs index 183a02da531..1f4bac92903 100644 --- a/crates/matrix-sdk-base/src/deserialized_responses.rs +++ b/crates/matrix-sdk-base/src/deserialized_responses.rs @@ -30,7 +30,7 @@ use ruma::{ StateEventContent, StaticStateEventContent, StrippedStateEvent, SyncStateEvent, }, serde::Raw, - EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, UserId, + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, UInt, UserId, }; use serde::Serialize; use unicode_normalization::UnicodeNormalization; @@ -476,6 +476,23 @@ impl MemberEvent { .unwrap_or_else(|| self.user_id().localpart()), ) } + + /// The optional reason why the membership changed. + pub fn reason(&self) -> Option<&str> { + match self { + MemberEvent::Sync(SyncStateEvent::Original(c)) => c.content.reason.as_deref(), + MemberEvent::Stripped(e) => e.content.reason.as_deref(), + _ => None, + } + } + + /// The optional timestamp for this member event. + pub fn timestamp(&self) -> Option { + match self { + MemberEvent::Sync(SyncStateEvent::Original(c)) => Some(c.origin_server_ts.0), + _ => None, + } + } } impl SyncOrStrippedState { diff --git a/crates/matrix-sdk-base/src/rooms/normal.rs b/crates/matrix-sdk-base/src/rooms/normal.rs index 1c8e1f1ea64..a7a5d4ea983 100644 --- a/crates/matrix-sdk-base/src/rooms/normal.rs +++ b/crates/matrix-sdk-base/src/rooms/normal.rs @@ -22,7 +22,7 @@ use std::{ use as_variant::as_variant; use bitflags::bitflags; -use eyeball::{SharedObservable, Subscriber}; +use eyeball::{AsyncLock, ObservableWriteGuard, SharedObservable, Subscriber}; use futures_util::{Stream, StreamExt}; #[cfg(feature = "experimental-sliding-sync")] use matrix_sdk_common::deserialized_responses::TimelineEventKind; @@ -52,7 +52,7 @@ use ruma::{ }, tag::{TagEventContent, Tags}, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, - RoomAccountDataEventType, SyncStateEvent, + RoomAccountDataEventType, StateEventType, SyncStateEvent, }, room::RoomType, serde::Raw, @@ -77,7 +77,8 @@ use crate::{ read_receipts::RoomReadReceipts, store::{DynStateStore, Result as StoreResult, StateStoreExt}, sync::UnreadNotificationsCount, - Error, MinimalStateEvent, OriginalMinimalStateEvent, RoomMemberships, + Error, MinimalStateEvent, OriginalMinimalStateEvent, RoomMemberships, StateStoreDataKey, + StateStoreDataValue, StoreError, }; /// Indicates that a notable update of `RoomInfo` has been applied, and why. @@ -167,6 +168,12 @@ pub struct Room { /// to disk but held in memory. #[cfg(all(feature = "e2e-encryption", feature = "experimental-sliding-sync"))] pub latest_encrypted_events: Arc>>>, + + /// A map for ids of room membership events in the knocking state linked to + /// the user id of the user affected by the member event, that the current + /// user has marked as seen so they can be ignored. + pub seen_knock_request_ids_map: + SharedObservable>, AsyncLock>, } /// The room summary containing member counts and members that should be used to @@ -289,6 +296,7 @@ impl Room { Self::MAX_ENCRYPTED_EVENTS, ))), room_info_notable_update_sender, + seen_knock_request_ids_map: SharedObservable::new_async(None), } } @@ -1169,6 +1177,88 @@ impl Room { pub fn pinned_event_ids(&self) -> Option> { self.inner.read().pinned_event_ids() } + + /// Mark a list of requests to join the room as seen, given their state + /// event ids. + pub async fn mark_knock_requests_as_seen(&self, user_ids: &[OwnedUserId]) -> StoreResult<()> { + let raw_user_ids: Vec<&str> = user_ids.iter().map(|id| id.as_str()).collect(); + let member_raw_events = self + .store + .get_state_events_for_keys(self.room_id(), StateEventType::RoomMember, &raw_user_ids) + .await?; + let mut event_to_user_ids = Vec::with_capacity(member_raw_events.len()); + + // Map the list of events ids to their user ids, if they are event ids for knock + // membership events. Log an error and continue otherwise. + for raw_event in member_raw_events { + let event = raw_event.cast::().deserialize()?; + match event { + SyncOrStrippedState::Sync(SyncStateEvent::Original(event)) => { + if event.content.membership == MembershipState::Knock { + event_to_user_ids.push((event.event_id, event.state_key)) + } else { + warn!("Could not mark knock event as seen: event {} for user {} is not in Knock membership state.", event.event_id, event.state_key); + } + } + _ => warn!( + "Could not mark knock event as seen: event for user {} is not valid.", + event.state_key() + ), + } + } + + let mut current_seen_events_guard = self.seen_knock_request_ids_map.write().await; + // We're not calling `get_seen_join_request_ids` here because we need to keep + // the Mutex's guard until we've updated the data + let mut current_seen_events = if current_seen_events_guard.is_none() { + self.load_cached_knock_request_ids().await? + } else { + current_seen_events_guard.clone().unwrap() + }; + + current_seen_events.extend(event_to_user_ids); + + ObservableWriteGuard::set( + &mut current_seen_events_guard, + Some(current_seen_events.clone()), + ); + + self.store + .set_kv_data( + StateStoreDataKey::SeenKnockRequests(self.room_id()), + StateStoreDataValue::SeenKnockRequests(current_seen_events), + ) + .await?; + + Ok(()) + } + + /// Get the list of seen knock request event ids in this room. + pub async fn get_seen_knock_request_ids( + &self, + ) -> Result, StoreError> { + let mut guard = self.seen_knock_request_ids_map.write().await; + if guard.is_none() { + ObservableWriteGuard::set( + &mut guard, + Some(self.load_cached_knock_request_ids().await?), + ); + } + Ok(guard.clone().unwrap_or_default()) + } + + /// This loads the current list of seen knock request ids from the state + /// store. + async fn load_cached_knock_request_ids( + &self, + ) -> StoreResult> { + Ok(self + .store + .get_kv_data(StateStoreDataKey::SeenKnockRequests(self.room_id())) + .await? + .and_then(|v| v.into_seen_knock_requests()) + .unwrap_or_default()) + } } // See https://github.com/matrix-org/matrix-rust-sdk/pull/3749#issuecomment-2312939823. @@ -1348,6 +1438,11 @@ impl RoomInfo { self.members_synced = false; } + /// Returns whether the room members are synced. + pub fn are_members_synced(&self) -> bool { + self.members_synced + } + /// Mark this Room as still missing some state information. pub fn mark_state_partially_synced(&mut self) { self.sync_info = SyncInfo::PartiallySynced; diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index abc5fca09b4..9148c9b34da 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -82,6 +82,7 @@ struct MemoryStoreInner { custom: HashMap, Vec>, send_queue_events: BTreeMap>, dependent_send_queue_events: BTreeMap>, + seen_knock_requests: BTreeMap>, } /// In-memory, non-persistent implementation of the `StateStore`. @@ -168,6 +169,11 @@ impl StateStore for MemoryStore { StateStoreDataKey::ComposerDraft(room_id) => { inner.composer_drafts.get(room_id).cloned().map(StateStoreDataValue::ComposerDraft) } + StateStoreDataKey::SeenKnockRequests(room_id) => inner + .seen_knock_requests + .get(room_id) + .cloned() + .map(StateStoreDataValue::SeenKnockRequests), }) } @@ -222,6 +228,14 @@ impl StateStore for MemoryStore { .expect("Session data not containing server capabilities"), ); } + StateStoreDataKey::SeenKnockRequests(room_id) => { + inner.seen_knock_requests.insert( + room_id.to_owned(), + value + .into_seen_knock_requests() + .expect("Session data is not a set of seen join request ids"), + ); + } } Ok(()) @@ -245,6 +259,9 @@ impl StateStore for MemoryStore { StateStoreDataKey::ComposerDraft(room_id) => { inner.composer_drafts.remove(room_id); } + StateStoreDataKey::SeenKnockRequests(room_id) => { + inner.seen_knock_requests.remove(room_id); + } } Ok(()) } diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs index 6e34f4fe263..5f651483f5b 100644 --- a/crates/matrix-sdk-base/src/store/traits.rs +++ b/crates/matrix-sdk-base/src/store/traits.rs @@ -1022,6 +1022,9 @@ pub enum StateStoreDataValue { /// /// [`ComposerDraft`]: Self::ComposerDraft ComposerDraft(ComposerDraft), + + /// A list of knock request ids marked as seen in a room. + SeenKnockRequests(BTreeMap), } /// Current draft of the composer for the room. @@ -1088,6 +1091,11 @@ impl StateStoreDataValue { pub fn into_server_capabilities(self) -> Option { as_variant!(self, Self::ServerCapabilities) } + + /// Get this value if it is the data for the ignored join requests. + pub fn into_seen_knock_requests(self) -> Option> { + as_variant!(self, Self::SeenKnockRequests) + } } /// A key for key-value data. @@ -1117,6 +1125,9 @@ pub enum StateStoreDataKey<'a> { /// /// [`ComposerDraft`]: Self::ComposerDraft ComposerDraft(&'a RoomId), + + /// A list of knock request ids marked as seen in a room. + SeenKnockRequests(&'a RoomId), } impl StateStoreDataKey<'_> { @@ -1142,6 +1153,10 @@ impl StateStoreDataKey<'_> { /// Key prefix to use for the [`ComposerDraft`][Self::ComposerDraft] /// variant. pub const COMPOSER_DRAFT: &'static str = "composer_draft"; + + /// Key prefix to use for the + /// [`SeenKnockRequests`][Self::SeenKnockRequests] variant. + pub const SEEN_KNOCK_REQUESTS: &'static str = "seen_knock_requests"; } #[cfg(test)] diff --git a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs index b8ca7442b27..01d386f354f 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs @@ -419,6 +419,9 @@ impl IndexeddbStateStore { StateStoreDataKey::ComposerDraft(room_id) => { self.encode_key(keys::KV, (StateStoreDataKey::COMPOSER_DRAFT, room_id)) } + StateStoreDataKey::SeenKnockRequests(room_id) => { + self.encode_key(keys::KV, (StateStoreDataKey::SEEN_KNOCK_REQUESTS, room_id)) + } } } } @@ -537,6 +540,10 @@ impl_state_store!({ .map(|f| self.deserialize_value::(&f)) .transpose()? .map(StateStoreDataValue::ComposerDraft), + StateStoreDataKey::SeenKnockRequests(_) => value + .map(|f| self.deserialize_value::>(&f)) + .transpose()? + .map(StateStoreDataValue::SeenKnockRequests), }; Ok(value) @@ -574,6 +581,11 @@ impl_state_store!({ StateStoreDataKey::ComposerDraft(_) => self.serialize_value( &value.into_composer_draft().expect("Session data not a composer draft"), ), + StateStoreDataKey::SeenKnockRequests(_) => self.serialize_value( + &value + .into_seen_knock_requests() + .expect("Session data is not a set of seen knock request ids"), + ), }; let tx = diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 36ff843cc71..adfd9d5b5a3 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -390,6 +390,9 @@ impl SqliteStateStore { StateStoreDataKey::ComposerDraft(room_id) => { Cow::Owned(format!("{}:{room_id}", StateStoreDataKey::COMPOSER_DRAFT)) } + StateStoreDataKey::SeenKnockRequests(room_id) => { + Cow::Owned(format!("{}:{room_id}", StateStoreDataKey::SEEN_KNOCK_REQUESTS)) + } }; self.encode_key(keys::KV_BLOB, &*key_s) @@ -995,6 +998,9 @@ impl StateStore for SqliteStateStore { StateStoreDataKey::ComposerDraft(_) => { StateStoreDataValue::ComposerDraft(self.deserialize_value(&data)?) } + StateStoreDataKey::SeenKnockRequests(_) => { + StateStoreDataValue::SeenKnockRequests(self.deserialize_value(&data)?) + } }) }) .transpose() @@ -1029,6 +1035,11 @@ impl StateStore for SqliteStateStore { StateStoreDataKey::ComposerDraft(_) => self.serialize_value( &value.into_composer_draft().expect("Session data not a composer draft"), )?, + StateStoreDataKey::SeenKnockRequests(_) => self.serialize_value( + &value + .into_seen_knock_requests() + .expect("Session data is not a set of seen knock request ids"), + )?, }; self.acquire() diff --git a/crates/matrix-sdk/src/room/knock_requests.rs b/crates/matrix-sdk/src/room/knock_requests.rs new file mode 100644 index 00000000000..de1409f5086 --- /dev/null +++ b/crates/matrix-sdk/src/room/knock_requests.rs @@ -0,0 +1,222 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use js_int::UInt; +use ruma::{EventId, OwnedEventId, OwnedMxcUri, OwnedUserId, RoomId}; + +use crate::{room::RoomMember, Error, Room}; + +/// A request to join a room with `knock` join rule. +#[derive(Debug, Clone)] +pub struct KnockRequest { + room: Room, + /// The event id of the event containing knock membership change. + pub event_id: OwnedEventId, + /// The timestamp when this request was created. + pub timestamp: Option, + /// Some general room member info to display. + pub member_info: KnockRequestMemberInfo, + /// Whether it's been marked as 'seen' by the client. + pub is_seen: bool, +} + +impl KnockRequest { + pub(crate) fn new( + room: &Room, + event_id: &EventId, + timestamp: Option, + member: KnockRequestMemberInfo, + is_seen: bool, + ) -> Self { + Self { + room: room.clone(), + event_id: event_id.to_owned(), + timestamp, + member_info: member, + is_seen, + } + } + + /// The room id for the `Room` from whose access is requested. + pub fn room_id(&self) -> &RoomId { + self.room.room_id() + } + + /// Marks the knock request as 'seen' so the client can ignore it in the + /// future. + pub async fn mark_as_seen(&self) -> Result<(), Error> { + self.room.mark_knock_requests_as_seen(&[self.member_info.user_id.to_owned()]).await?; + Ok(()) + } + + /// Accepts the knock request by inviting the user to the room. + pub async fn accept(&self) -> Result<(), Error> { + self.room.invite_user_by_id(&self.member_info.user_id).await + } + + /// Declines the knock request by kicking the user from the room, with an + /// optional reason. + pub async fn decline(&self, reason: Option<&str>) -> Result<(), Error> { + self.room.kick_user(&self.member_info.user_id, reason).await + } + + /// Declines the knock request by banning the user from the room, with an + /// optional reason. + pub async fn decline_and_ban(&self, reason: Option<&str>) -> Result<(), Error> { + self.room.ban_user(&self.member_info.user_id, reason).await + } +} + +/// General room member info to display along with the join request. +#[derive(Debug, Clone)] +pub struct KnockRequestMemberInfo { + /// The user id for the room member requesting access. + pub user_id: OwnedUserId, + /// The optional display name of the room member requesting access. + pub display_name: Option, + /// The optional avatar url of the room member requesting access. + pub avatar_url: Option, + /// An optional reason why the user wants access to the room. + pub reason: Option, +} + +impl KnockRequestMemberInfo { + pub(crate) fn from_member(member: &RoomMember) -> Self { + Self { + user_id: member.user_id().to_owned(), + display_name: member.display_name().map(ToOwned::to_owned), + avatar_url: member.avatar_url().map(ToOwned::to_owned), + reason: member.event().reason().map(ToOwned::to_owned), + } + } +} + +// The http mocking library is not supported for wasm32 +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use matrix_sdk_test::{async_test, event_factory::EventFactory, JoinedRoomBuilder}; + use ruma::{ + event_id, + events::room::member::{MembershipState, RoomMemberEventContent}, + owned_user_id, room_id, user_id, EventId, + }; + + use crate::{ + room::knock_requests::{KnockRequest, KnockRequestMemberInfo}, + test_utils::mocks::MatrixMockServer, + Room, + }; + + #[async_test] + async fn test_mark_as_seen() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + let room_id = room_id!("!a:b.c"); + let event_id = event_id!("$a:b.c"); + let user_id = user_id!("@alice:b.c"); + + let f = EventFactory::new().room(room_id); + let joined_room_builder = JoinedRoomBuilder::new(room_id).add_state_bulk(vec![f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .event_id(event_id) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast()]); + let room = server.sync_room(&client, joined_room_builder).await; + + let knock_request = make_knock_request(&room, Some(event_id)); + + // When we mark the knock request as seen + knock_request.mark_as_seen().await.expect("Failed to mark as seen"); + + // Then we can check it was successfully marked as seen from the room + let seen_ids = + room.get_seen_knock_request_ids().await.expect("Failed to get seen join request ids"); + assert_eq!(seen_ids.len(), 1); + assert_eq!( + seen_ids.into_iter().next().expect("Couldn't load next item"), + (event_id.to_owned(), user_id.to_owned()) + ); + } + + #[async_test] + async fn test_accept() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + let room_id = room_id!("!a:b.c"); + + let room = server.sync_joined_room(&client, room_id).await; + + let knock_request = make_knock_request(&room, None); + + // The /invite endpoint must be called once + server.mock_invite_user_by_id().ok().mock_once().mount().await; + + // When we accept the knock request + knock_request.accept().await.expect("Failed to accept the request"); + } + + #[async_test] + async fn test_decline() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + let room_id = room_id!("!a:b.c"); + + let room = server.sync_joined_room(&client, room_id).await; + + let knock_request = make_knock_request(&room, None); + + // The /kick endpoint must be called once + server.mock_kick_user().ok().mock_once().mount().await; + + // When we decline the knock request + knock_request.decline(None).await.expect("Failed to decline the request"); + } + + #[async_test] + async fn test_decline_and_ban() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + let room_id = room_id!("!a:b.c"); + + let room = server.sync_joined_room(&client, room_id).await; + + let knock_request = make_knock_request(&room, None); + + // The /ban endpoint must be called once + server.mock_ban_user().ok().mock_once().mount().await; + + // When we decline the knock request and ban the user from the room + knock_request + .decline_and_ban(None) + .await + .expect("Failed to decline the request and ban the user"); + } + + fn make_knock_request(room: &Room, event_id: Option<&EventId>) -> KnockRequest { + KnockRequest::new( + room, + event_id.unwrap_or(event_id!("$a:b.c")), + None, + KnockRequestMemberInfo { + user_id: owned_user_id!("@alice:b.c"), + display_name: None, + avatar_url: None, + reason: None, + }, + false, + ) + } +} diff --git a/crates/matrix-sdk/src/room/mod.rs b/crates/matrix-sdk/src/room/mod.rs index 224b64ad352..30a3eed4de5 100644 --- a/crates/matrix-sdk/src/room/mod.rs +++ b/crates/matrix-sdk/src/room/mod.rs @@ -22,6 +22,7 @@ use std::{ time::Duration, }; +use async_stream::stream; #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] use async_trait::async_trait; use eyeball::SharedObservable; @@ -85,6 +86,7 @@ use ruma::{ avatar::{self, RoomAvatarEventContent}, encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility, + member::{MembershipChange, SyncRoomMemberEvent}, message::{ AudioInfo, AudioMessageEventContent, FileInfo, FileMessageEventContent, FormattedBody, ImageMessageEventContent, MessageType, RoomMessageEventContent, @@ -116,6 +118,7 @@ use ruma::{ use serde::de::DeserializeOwned; use thiserror::Error; use tokio::sync::broadcast; +use tokio_stream::StreamExt; use tracing::{debug, info, instrument, warn}; use self::futures::{SendAttachment, SendMessageLikeEvent, SendRawMessageLikeEvent}; @@ -135,7 +138,10 @@ use crate::{ live_location_share::ObservableLiveLocation, media::{MediaFormat, MediaRequestParameters}, notification_settings::{IsEncrypted, IsOneToOne, RoomNotificationMode}, - room::power_levels::{RoomPowerLevelChanges, RoomPowerLevelsExt}, + room::{ + knock_requests::{KnockRequest, KnockRequestMemberInfo}, + power_levels::{RoomPowerLevelChanges, RoomPowerLevelsExt}, + }, sync::RoomUpdate, utils::{IntoRawMessageLikeEventContent, IntoRawStateEventContent}, BaseRoom, Client, Error, HttpResult, Result, RoomState, TransmissionProgress, @@ -146,6 +152,8 @@ use crate::{crypto::types::events::CryptoContextInfo, encryption::backups::Backu pub mod edit; pub mod futures; pub mod identity_status_changes; +/// Contains code related to requests to join a room. +pub mod knock_requests; mod member; mod messages; pub mod power_levels; @@ -3205,6 +3213,131 @@ impl Room { pub fn observe_live_location_shares(&self) -> ObservableLiveLocation { ObservableLiveLocation::new(&self.client, self.room_id()) } + + /// Subscribe to knock requests in this `Room`. + /// + /// The current requests to join the room will be emitted immediately + /// when subscribing. + /// + /// A new set of knock requests will be emitted whenever: + /// - A new member event is received. + /// - A knock request is marked as seen. + /// - A sync is gappy (limited), so room membership information may be + /// outdated. + pub async fn subscribe_to_knock_requests( + &self, + ) -> Result>> { + let this = Arc::new(self.clone()); + + let room_member_events_observer = + self.client.observe_room_events::(this.room_id()); + + let current_seen_ids = self.get_seen_knock_request_ids().await?; + let mut seen_request_ids_stream = self + .seen_knock_request_ids_map + .subscribe() + .await + .map(|values| values.unwrap_or_default()); + + let mut room_info_stream = self.subscribe_info(); + + let combined_stream = stream! { + // Emit current requests to join + match this.get_current_join_requests(¤t_seen_ids).await { + Ok(initial_requests) => yield initial_requests, + Err(err) => warn!("Failed to get initial requests to join: {err}") + } + + let mut requests_stream = room_member_events_observer.subscribe(); + let mut seen_ids = current_seen_ids.clone(); + + loop { + // This is equivalent to a combine stream operation, triggering a new emission + // when any of the branches changes + tokio::select! { + Some((event, _)) = requests_stream.next() => { + if let Some(event) = event.as_original() { + // If we can calculate the membership change, try to emit only when needed + let emit = if event.prev_content().is_some() { + matches!(event.membership_change(), + MembershipChange::Banned | + MembershipChange::Knocked | + MembershipChange::KnockAccepted | + MembershipChange::KnockDenied | + MembershipChange::KnockRetracted + ) + } else { + // If we can't calculate the membership change, assume we need to + // emit updated values + true + }; + + if emit { + match this.get_current_join_requests(&seen_ids).await { + Ok(requests) => yield requests, + Err(err) => { + warn!("Failed to get updated knock requests on new member event: {err}") + } + } + } + } + } + + Some(new_seen_ids) = seen_request_ids_stream.next() => { + // Update the current seen ids + seen_ids = new_seen_ids; + + // If seen requests have changed we need to recalculate + // all the knock requests + match this.get_current_join_requests(&seen_ids).await { + Ok(requests) => yield requests, + Err(err) => { + warn!("Failed to get updated knock requests on seen ids changed: {err}") + } + } + } + + Some(room_info) = room_info_stream.next() => { + // We need to emit new items when we may have missing room members: + // this usually happens after a gappy (limited) sync + if !room_info.are_members_synced() { + match this.get_current_join_requests(&seen_ids).await { + Ok(requests) => yield requests, + Err(err) => { + warn!("Failed to get updated knock requests on gappy (limited) sync: {err}") + } + } + } + } + // If the streams in all branches are closed, stop the loop + else => break, + } + } + }; + + Ok(combined_stream) + } + + async fn get_current_join_requests( + &self, + seen_request_ids: &BTreeMap, + ) -> Result> { + Ok(self + .members(RoomMemberships::KNOCK) + .await? + .into_iter() + .filter_map(|member| { + let event_id = member.event().event_id()?; + Some(KnockRequest::new( + self, + event_id, + member.event().timestamp(), + KnockRequestMemberInfo::from_member(&member), + seen_request_ids.contains_key(event_id), + )) + }) + .collect()) + } } #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] @@ -3494,9 +3627,14 @@ pub struct TryFromReportedContentScoreError(()); mod tests { use matrix_sdk_base::{store::ComposerDraftType, ComposerDraft, SessionMeta}; use matrix_sdk_test::{ - async_test, test_json, JoinedRoomBuilder, StateTestEvent, SyncResponseBuilder, + async_test, event_factory::EventFactory, test_json, JoinedRoomBuilder, StateTestEvent, + SyncResponseBuilder, + }; + use ruma::{ + device_id, event_id, + events::room::member::{MembershipState, RoomMemberEventContent}, + int, room_id, user_id, }; - use ruma::{device_id, int, user_id}; use wiremock::{ matchers::{header, method, path_regex}, Mock, MockServer, ResponseTemplate, @@ -3506,7 +3644,7 @@ mod tests { use crate::{ config::RequestConfig, matrix_auth::{MatrixSession, MatrixSessionTokens}, - test_utils::logged_in_client, + test_utils::{logged_in_client, mocks::MatrixMockServer}, Client, }; @@ -3681,4 +3819,42 @@ mod tests { room.clear_composer_draft().await.unwrap(); assert_eq!(room.load_composer_draft().await.unwrap(), None); } + + #[async_test] + async fn test_mark_join_requests_as_seen() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + let event_id = event_id!("$a:b.c"); + let room_id = room_id!("!a:b.c"); + let user_id = user_id!("@alice:b.c"); + + let f = EventFactory::new().room(room_id); + let joined_room_builder = JoinedRoomBuilder::new(room_id).add_state_bulk(vec![f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .event_id(event_id) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast()]); + let room = server.sync_room(&client, joined_room_builder).await; + + // When loading the initial seen ids, there are none + let seen_ids = + room.get_seen_knock_request_ids().await.expect("Couldn't load seen join request ids"); + assert!(seen_ids.is_empty()); + + // We mark a random event id as seen + room.mark_knock_requests_as_seen(&[user_id.to_owned()]) + .await + .expect("Couldn't mark join request as seen"); + + // Then we can check it was successfully marked as seen + let seen_ids = + room.get_seen_knock_request_ids().await.expect("Couldn't load seen join request ids"); + assert_eq!(seen_ids.len(), 1); + assert_eq!( + seen_ids.into_iter().next().expect("No next value"), + (event_id.to_owned(), user_id.to_owned()) + ) + } } diff --git a/crates/matrix-sdk/src/test_utils/mocks.rs b/crates/matrix-sdk/src/test_utils/mocks.rs index 74d25a5ac01..19ffee007ea 100644 --- a/crates/matrix-sdk/src/test_utils/mocks.rs +++ b/crates/matrix-sdk/src/test_utils/mocks.rs @@ -29,7 +29,10 @@ use matrix_sdk_test::{ }; use ruma::{ directory::PublicRoomsChunk, - events::{AnyStateEvent, AnyTimelineEvent, MessageLikeEventType, StateEventType}, + events::{ + room::member::RoomMemberEvent, AnyStateEvent, AnyTimelineEvent, MessageLikeEventType, + StateEventType, + }, serde::Raw, time::Duration, MxcUri, OwnedEventId, OwnedRoomId, RoomId, ServerName, @@ -607,6 +610,138 @@ impl MatrixMockServer { .and(header("authorization", "Bearer 1234")); MockEndpoint { mock, server: &self.server, endpoint: DeleteRoomKeysVersionEndpoint } } + + /// Create a prebuilt mock for getting the room members in a room. + /// + /// # Examples + /// + /// ``` # + /// tokio_test::block_on(async { + /// use matrix_sdk_base::RoomMemberships; + /// use ruma::events::room::member::MembershipState; + /// use ruma::events::room::member::RoomMemberEventContent; + /// use ruma::user_id; + /// use matrix_sdk_test::event_factory::EventFactory; + /// use matrix_sdk::{ + /// ruma::{event_id, room_id}, + /// test_utils::mocks::MatrixMockServer, + /// }; + /// let mock_server = MatrixMockServer::new().await; + /// let client = mock_server.client_builder().build().await; + /// let event_id = event_id!("$id"); + /// let room_id = room_id!("!room_id:localhost"); + /// + /// let f = EventFactory::new().room(room_id); + /// let alice_user_id = user_id!("@alice:b.c"); + /// let alice_knock_event = f + /// .event(RoomMemberEventContent::new(MembershipState::Knock)) + /// .event_id(event_id) + /// .sender(alice_user_id) + /// .state_key(alice_user_id) + /// .into_raw_timeline() + /// .cast(); + /// + /// mock_server.mock_get_members().ok(vec![alice_knock_event]).mock_once().mount().await; + /// let room = mock_server.sync_joined_room(&client, room_id).await; + /// + /// let members = room.members(RoomMemberships::all()).await.unwrap(); + /// assert_eq!(members.len(), 1); + /// # }); + /// ``` + pub fn mock_get_members(&self) -> MockEndpoint<'_, GetRoomMembersEndpoint> { + let mock = + Mock::given(method("GET")).and(path_regex(r"^/_matrix/client/v3/rooms/.*/members$")); + MockEndpoint { mock, server: &self.server, endpoint: GetRoomMembersEndpoint } + } + + /// Creates a prebuilt mock for inviting a user to a room by its id. + /// + /// # Examples + /// + /// ``` + /// # use ruma::user_id; + /// tokio_test::block_on(async { + /// use matrix_sdk::{ + /// ruma::room_id, + /// test_utils::mocks::MatrixMockServer, + /// }; + /// + /// let mock_server = MatrixMockServer::new().await; + /// let client = mock_server.client_builder().build().await; + /// + /// mock_server.mock_invite_user_by_id().ok().mock_once().mount().await; + /// + /// let room = mock_server + /// .sync_joined_room(&client, room_id!("!room_id:localhost")) + /// .await; + /// + /// room.invite_user_by_id(user_id!("@alice:localhost")).await.unwrap(); + /// # anyhow::Ok(()) }); + /// ``` + pub fn mock_invite_user_by_id(&self) -> MockEndpoint<'_, InviteUserByIdEndpoint> { + let mock = + Mock::given(method("POST")).and(path_regex(r"^/_matrix/client/v3/rooms/.*/invite$")); + MockEndpoint { mock, server: &self.server, endpoint: InviteUserByIdEndpoint } + } + + /// Creates a prebuilt mock for kicking a user from a room. + /// + /// # Examples + /// + /// ``` + /// # use ruma::user_id; + /// tokio_test::block_on(async { + /// use matrix_sdk::{ + /// ruma::room_id, + /// test_utils::mocks::MatrixMockServer, + /// }; + /// + /// let mock_server = MatrixMockServer::new().await; + /// let client = mock_server.client_builder().build().await; + /// + /// mock_server.mock_kick_user().ok().mock_once().mount().await; + /// + /// let room = mock_server + /// .sync_joined_room(&client, room_id!("!room_id:localhost")) + /// .await; + /// + /// room.kick_user(user_id!("@alice:localhost"), None).await.unwrap(); + /// # anyhow::Ok(()) }); + /// ``` + pub fn mock_kick_user(&self) -> MockEndpoint<'_, KickUserEndpoint> { + let mock = + Mock::given(method("POST")).and(path_regex(r"^/_matrix/client/v3/rooms/.*/kick")); + MockEndpoint { mock, server: &self.server, endpoint: KickUserEndpoint } + } + + /// Creates a prebuilt mock for banning a user from a room. + /// + /// # Examples + /// + /// ``` + /// # use ruma::user_id; + /// tokio_test::block_on(async { + /// use matrix_sdk::{ + /// ruma::room_id, + /// test_utils::mocks::MatrixMockServer, + /// }; + /// + /// let mock_server = MatrixMockServer::new().await; + /// let client = mock_server.client_builder().build().await; + /// + /// mock_server.mock_ban_user().ok().mock_once().mount().await; + /// + /// let room = mock_server + /// .sync_joined_room(&client, room_id!("!room_id:localhost")) + /// .await; + /// + /// room.ban_user(user_id!("@alice:localhost"), None).await.unwrap(); + /// # anyhow::Ok(()) }); + /// ``` + pub fn mock_ban_user(&self) -> MockEndpoint<'_, BanUserEndpoint> { + let mock = Mock::given(method("POST")).and(path_regex(r"^/_matrix/client/v3/rooms/.*/ban")); + MockEndpoint { mock, server: &self.server, endpoint: BanUserEndpoint } + } } /// Parameter to [`MatrixMockServer::sync_room`]. @@ -1023,7 +1158,7 @@ impl<'a> MockEndpoint<'a, RoomSendEndpoint> { /// /// let response = room.client().send(r, None).await.unwrap(); /// // The delayed `m.room.message` event type should be mocked by the server. - /// assert_eq!("$some_id", response.delay_id); + /// assert_eq!("$some_id", response.delay_id); /// # anyhow::Ok(()) }); /// ``` pub fn with_delay(self, delay: Duration) -> Self { @@ -1761,3 +1896,49 @@ impl<'a> MockEndpoint<'a, DeleteRoomKeysVersionEndpoint> { MatrixMock { server: self.server, mock } } } + +/// A prebuilt mock for `GET /members` request. +pub struct GetRoomMembersEndpoint; + +impl<'a> MockEndpoint<'a, GetRoomMembersEndpoint> { + /// Returns a successful get members request with a list of members. + pub fn ok(self, members: Vec>) -> MatrixMock<'a> { + let mock = self.mock.respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "chunk": members, + }))); + MatrixMock { server: self.server, mock } + } +} + +/// A prebuilt mock for `POST /invite` request. +pub struct InviteUserByIdEndpoint; + +impl<'a> MockEndpoint<'a, InviteUserByIdEndpoint> { + /// Returns a successful invite user by id request. + pub fn ok(self) -> MatrixMock<'a> { + let mock = self.mock.respond_with(ResponseTemplate::new(200).set_body_json(json!({}))); + MatrixMock { server: self.server, mock } + } +} + +/// A prebuilt mock for `POST /kick` request. +pub struct KickUserEndpoint; + +impl<'a> MockEndpoint<'a, KickUserEndpoint> { + /// Returns a successful kick user request. + pub fn ok(self) -> MatrixMock<'a> { + let mock = self.mock.respond_with(ResponseTemplate::new(200).set_body_json(json!({}))); + MatrixMock { server: self.server, mock } + } +} + +/// A prebuilt mock for `POST /ban` request. +pub struct BanUserEndpoint; + +impl<'a> MockEndpoint<'a, BanUserEndpoint> { + /// Returns a successful ban user request. + pub fn ok(self) -> MatrixMock<'a> { + let mock = self.mock.respond_with(ResponseTemplate::new(200).set_body_json(json!({}))); + MatrixMock { server: self.server, mock } + } +} diff --git a/crates/matrix-sdk/tests/integration/room/joined.rs b/crates/matrix-sdk/tests/integration/room/joined.rs index fa0b3f66af2..4e6ca00de93 100644 --- a/crates/matrix-sdk/tests/integration/room/joined.rs +++ b/crates/matrix-sdk/tests/integration/room/joined.rs @@ -3,8 +3,9 @@ use std::{ time::Duration, }; -use futures_util::future::join_all; +use futures_util::{future::join_all, pin_mut}; use matrix_sdk::{ + assert_next_with_timeout, config::SyncSettings, room::{edit::EditedContent, Receipts, ReportedContentScore, RoomMemberRole}, test_utils::mocks::MatrixMockServer, @@ -24,12 +25,16 @@ use ruma::{ events::{ direct::DirectUserIdentifier, receipt::ReceiptThread, - room::message::{RoomMessageEventContent, RoomMessageEventContentWithoutRelation}, + room::{ + member::{MembershipState, RoomMemberEventContent}, + message::{RoomMessageEventContent, RoomMessageEventContentWithoutRelation}, + }, TimelineEventType, }, int, mxc_uri, owned_event_id, room_id, thirdparty, user_id, OwnedUserId, TransactionId, }; use serde_json::{from_value, json, Value}; +use stream_assert::assert_pending; use wiremock::{ matchers::{body_json, body_partial_json, header, method, path_regex}, Mock, ResponseTemplate, @@ -833,3 +838,113 @@ async fn test_enable_encryption_doesnt_stay_unencrypted() { assert!(room.is_encrypted().await.unwrap()); } + +#[async_test] +async fn test_subscribe_to_requests_to_join() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + + server.mock_room_state_encryption().plain().mount().await; + + let room_id = room_id!("!a:b.c"); + let f = EventFactory::new().room(room_id); + + let user_id = user_id!("@alice:b.c"); + let knock_event_id = event_id!("$alice-knock:b.c"); + let knock_event = f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .event_id(knock_event_id) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast(); + + server.mock_get_members().ok(vec![knock_event]).mock_once().mount().await; + + let room = server.sync_joined_room(&client, room_id).await; + let stream = room.subscribe_to_knock_requests().await.unwrap(); + + pin_mut!(stream); + + // We receive an initial knock request from Alice + let initial = assert_next_with_timeout!(stream, 100); + assert_eq!(initial.len(), 1); + + let knock_request = &initial[0]; + assert_eq!(knock_request.event_id, knock_event_id); + assert!(!knock_request.is_seen); + + // We then mark the knock request as seen + room.mark_knock_requests_as_seen(&[user_id.to_owned()]).await.unwrap(); + + // Now it's received again as seen + let seen = assert_next_with_timeout!(stream, 100); + assert_eq!(initial.len(), 1); + let seen_knock = &seen[0]; + assert_eq!(seen_knock.event_id, knock_event_id); + assert!(seen_knock.is_seen); + + // If we then receive a new member event for Alice that's not 'knock' + let joined_room_builder = JoinedRoomBuilder::new(room_id).add_state_bulk(vec![f + .event(RoomMemberEventContent::new(MembershipState::Invite)) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast()]); + server.sync_room(&client, joined_room_builder).await; + + // The knock requests are now empty + let updated_requests = assert_next_with_timeout!(stream, 100); + assert!(updated_requests.is_empty()); + + // There should be no other knock requests + assert_pending!(stream) +} + +#[async_test] +async fn test_subscribe_to_requests_to_join_reloads_members_on_limited_sync() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + + server.mock_room_state_encryption().plain().mount().await; + + let room_id = room_id!("!a:b.c"); + let f = EventFactory::new().room(room_id); + + let user_id = user_id!("@alice:b.c"); + let knock_event = f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast(); + + server + .mock_get_members() + .ok(vec![knock_event]) + // The endpoint will be called twice: + // 1. For the initial loading of room members. + // 2. When a gappy (limited) sync is received. + .expect(2) + .mount() + .await; + + let room = server.sync_joined_room(&client, room_id).await; + let stream = room.subscribe_to_knock_requests().await.unwrap(); + + pin_mut!(stream); + + // We receive an initial knock request from Alice + let initial = assert_next_with_timeout!(stream, 500); + assert!(!initial.is_empty()); + + // This limited sync should trigger a new emission of knock requests, with a + // reloading of the room members + server.sync_room(&client, JoinedRoomBuilder::new(room_id).set_timeline_limited()).await; + + // We should receive a new list of knock requests + assert_next_with_timeout!(stream, 500); + + // There should be no other knock requests + assert_pending!(stream) +}