diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index b8d6c6a52..c76ec7bc0 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -26,11 +26,12 @@ use mas_handlers::HttpClientFactory; use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ - compat::{CompatAccessTokenRepository, CompatSessionRepository}, + compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository}, job::{ DeactivateUserJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob, }, - user::{UserEmailRepository, UserPasswordRepository, UserRepository}, + oauth2::OAuth2SessionFilter, + user::{BrowserSessionFilter, UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, }; use mas_storage_pg::{DatabaseError, PgRepository}; @@ -348,83 +349,43 @@ impl Options { .await? .context("User not found")?; - let compat_sessions_ids: Vec = sqlx::query_scalar( - r" - SELECT compat_session_id FROM compat_sessions - WHERE user_id = $1 AND finished_at IS NULL - ", - ) - .bind(Uuid::from(user.id)) - .fetch_all(&mut **repo) - .await?; - - for id in compat_sessions_ids { - let id = id.into(); - let compat_session = repo - .compat_session() - .lookup(id) - .await? - .context("Session not found")?; - info!(%compat_session.id, %compat_session.device, "Killing compat session"); - - if dry_run { - continue; - } - } - - let oauth2_sessions_ids: Vec = sqlx::query_scalar( - r" - SELECT oauth2_sessions.oauth2_session_id - FROM oauth2_sessions - INNER JOIN user_sessions USING (user_session_id) - WHERE user_sessions.user_id = $1 AND oauth2_sessions.finished_at IS NULL - ", - ) - .bind(Uuid::from(user.id)) - .fetch_all(&mut **repo) - .await?; - - for id in oauth2_sessions_ids { - let id = id.into(); - let oauth2_session = repo - .oauth2_session() - .lookup(id) - .await? - .context("Session not found")?; - info!(%oauth2_session.id, %oauth2_session.scope, "Killing oauth2 session"); + let filter = CompatSessionFilter::new().for_user(&user).active_only(); + let affected = if dry_run { + repo.compat_session().count(filter).await? + } else { + repo.compat_session().finish_bulk(&clock, filter).await? + }; - if dry_run { - continue; - } - repo.oauth2_session().finish(&clock, oauth2_session).await?; + match affected { + 0 => info!("No active compatibility sessions to end"), + 1 => info!("Ended 1 active compatibility session"), + _ => info!("Ended {affected} active compatibility sessions"), } - let user_sessions_ids: Vec = sqlx::query_scalar( - r" - SELECT user_session_id FROM user_sessions - WHERE user_id = $1 AND finished_at IS NULL - ", - ) - .bind(Uuid::from(user.id)) - .fetch_all(&mut **repo) - .await?; + let filter = OAuth2SessionFilter::new().for_user(&user).active_only(); + let affected = if dry_run { + repo.oauth2_session().count(filter).await? + } else { + repo.oauth2_session().finish_bulk(&clock, filter).await? + }; - for id in user_sessions_ids { - let id = id.into(); - let browser_session = repo - .browser_session() - .lookup(id) - .await? - .context("Session not found")?; - info!(%browser_session.id, "Killing browser session"); + match affected { + 0 => info!("No active compatibility sessions to end"), + 1 => info!("Ended 1 active OAuth 2.0 session"), + _ => info!("Ended {affected} active OAuth 2.0 sessions"), + }; - if dry_run { - continue; - } + let filter = BrowserSessionFilter::new().for_user(&user).active_only(); + let affected = if dry_run { + repo.browser_session().count(filter).await? + } else { + repo.browser_session().finish_bulk(&clock, filter).await? + }; - repo.browser_session() - .finish(&clock, browser_session) - .await?; + match affected { + 0 => info!("No active browser sessions to end"), + 1 => info!("Ended 1 active browser session"), + _ => info!("Ended {affected} active browser sessions"), } // Schedule a job to sync the devices of the user with the homeserver diff --git a/crates/storage-pg/.sqlx/query-8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b.json b/crates/storage-pg/.sqlx/query-03eee34f05df9c79f8ca5bfb1af339b3fcea95ba59395106318366a6ef432d85.json similarity index 62% rename from crates/storage-pg/.sqlx/query-8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b.json rename to crates/storage-pg/.sqlx/query-03eee34f05df9c79f8ca5bfb1af339b3fcea95ba59395106318366a6ef432d85.json index db3f56f37..12c48424a 100644 --- a/crates/storage-pg/.sqlx/query-8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b.json +++ b/crates/storage-pg/.sqlx/query-03eee34f05df9c79f8ca5bfb1af339b3fcea95ba59395106318366a6ef432d85.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE user_sessions\n SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(user_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE user_sessions.user_session_id = t.user_session_id\n ", + "query": "\n UPDATE user_sessions\n SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])\n AS t(user_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE user_sessions.user_session_id = t.user_session_id\n ", "describe": { "columns": [], "parameters": { @@ -12,5 +12,5 @@ }, "nullable": [] }, - "hash": "8a7461d5f633b7b441b9bc83f78900ccee5a76328a7fad97b650f7e4c921bd7b" + "hash": "03eee34f05df9c79f8ca5bfb1af339b3fcea95ba59395106318366a6ef432d85" } diff --git a/crates/storage-pg/.sqlx/query-4070549b235e059eaeccc4751b480ccb30ad5b62d933b4efb03491124a9361ad.json b/crates/storage-pg/.sqlx/query-4070549b235e059eaeccc4751b480ccb30ad5b62d933b4efb03491124a9361ad.json deleted file mode 100644 index ed53254fa..000000000 --- a/crates/storage-pg/.sqlx/query-4070549b235e059eaeccc4751b480ccb30ad5b62d933b4efb03491124a9361ad.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO compat_sessions \n (compat_session_id, user_id, device_id,\n user_session_id, created_at, is_synapse_admin)\n VALUES ($1, $2, $3, $4, $5, $6)\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Text", - "Uuid", - "Timestamptz", - "Bool" - ] - }, - "nullable": [] - }, - "hash": "4070549b235e059eaeccc4751b480ccb30ad5b62d933b4efb03491124a9361ad" -} diff --git a/crates/storage-pg/.sqlx/query-d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930.json b/crates/storage-pg/.sqlx/query-55bc51efddf7a1cf06610fdb20d46beca29964733338ea4fec2a29393f031c4f.json similarity index 61% rename from crates/storage-pg/.sqlx/query-d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930.json rename to crates/storage-pg/.sqlx/query-55bc51efddf7a1cf06610fdb20d46beca29964733338ea4fec2a29393f031c4f.json index 14d447765..de0a51e4a 100644 --- a/crates/storage-pg/.sqlx/query-d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930.json +++ b/crates/storage-pg/.sqlx/query-55bc51efddf7a1cf06610fdb20d46beca29964733338ea4fec2a29393f031c4f.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE compat_sessions\n SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) \n AS t(compat_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE compat_sessions.compat_session_id = t.compat_session_id\n ", + "query": "\n UPDATE compat_sessions\n SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])\n AS t(compat_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE compat_sessions.compat_session_id = t.compat_session_id\n ", "describe": { "columns": [], "parameters": { @@ -12,5 +12,5 @@ }, "nullable": [] }, - "hash": "d4d25682c10be7a3e3ee989fb9dae92e19023b8ecb6fe6e1d7cabe2cf0efd930" + "hash": "55bc51efddf7a1cf06610fdb20d46beca29964733338ea4fec2a29393f031c4f" } diff --git a/crates/storage-pg/.sqlx/query-cf1273b8aaaccedeb212a971d5e8e0dd23bfddab0ec08ee192783e103a1c4766.json b/crates/storage-pg/.sqlx/query-cf1273b8aaaccedeb212a971d5e8e0dd23bfddab0ec08ee192783e103a1c4766.json new file mode 100644 index 000000000..35f6b5973 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-cf1273b8aaaccedeb212a971d5e8e0dd23bfddab0ec08ee192783e103a1c4766.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO compat_sessions\n (compat_session_id, user_id, device_id,\n user_session_id, created_at, is_synapse_admin)\n VALUES ($1, $2, $3, $4, $5, $6)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Uuid", + "Timestamptz", + "Bool" + ] + }, + "nullable": [] + }, + "hash": "cf1273b8aaaccedeb212a971d5e8e0dd23bfddab0ec08ee192783e103a1c4766" +} diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 092b8c8b8..45073667a 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -288,6 +288,16 @@ mod tests { .unwrap(), 1 ); + + // Check that we can batch finish sessions + let affected = repo + .compat_session() + .finish_bulk(&clock, all.sso_login_only().active_only()) + .await + .unwrap(); + assert_eq!(affected, 1); + assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2); + assert_eq!(repo.compat_session().count(active).await.unwrap(), 0); } #[sqlx::test(migrator = "crate::MIGRATOR")] diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 2b183253f..b85801089 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -271,7 +271,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { sqlx::query!( r#" - INSERT INTO compat_sessions + INSERT INTO compat_sessions (compat_session_id, user_id, device_id, user_session_id, created_at, is_synapse_admin) VALUES ($1, $2, $3, $4, $5, $6) @@ -341,6 +341,64 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { Ok(compat_session) } + #[tracing::instrument( + name = "db.compat_session.finish_bulk", + skip_all, + fields(db.statement), + err, + )] + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: CompatSessionFilter<'_>, + ) -> Result { + let finished_at = clock.now(); + let (sql, arguments) = Query::update() + .table(CompatSessions::Table) + .value(CompatSessions::FinishedAt, finished_at) + .and_where_option(filter.user().map(|user| { + Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null() + } else { + Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null() + } + })) + .and_where_option(filter.auth_type().map(|auth_type| { + // This builds either a: + // `WHERE compat_session_id = ANY(...)` + // or a `WHERE compat_session_id <> ALL(...)` + let compat_sso_logins = Query::select() + .expr(Expr::col(( + CompatSsoLogins::Table, + CompatSsoLogins::CompatSessionId, + ))) + .from(CompatSsoLogins::Table) + .take(); + + if auth_type.is_sso_login() { + Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)) + .eq(Expr::any(compat_sso_logins)) + } else { + Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)) + .ne(Expr::all(compat_sso_logins)) + } + })) + .and_where_option(filter.device().map(|device| { + Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str()) + })) + .build_sqlx(PostgresQueryBuilder); + + let res = sqlx::query_with(&sql, arguments) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } + #[tracing::instrument( name = "db.compat_session.list", skip_all, @@ -545,7 +603,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip) FROM ( SELECT * - FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) AS t(compat_session_id, last_active_at, last_active_ip) ) AS t WHERE compat_sessions.compat_session_id = t.compat_session_id diff --git a/crates/storage-pg/src/oauth2/mod.rs b/crates/storage-pg/src/oauth2/mod.rs index a6b3ca555..fe66cff41 100644 --- a/crates/storage-pg/src/oauth2/mod.rs +++ b/crates/storage-pg/src/oauth2/mod.rs @@ -709,6 +709,37 @@ mod tests { assert_eq!(list.edges.len(), 1); assert_eq!(list.edges[0], session11); assert_eq!(repo.oauth2_session().count(filter).await.unwrap(), 1); + + // Finish all sessions of a client in batch + let affected = repo + .oauth2_session() + .finish_bulk( + &clock, + OAuth2SessionFilter::new() + .for_client(&client1) + .active_only(), + ) + .await + .unwrap(); + assert_eq!(affected, 1); + + // We should have 3 finished sessions + assert_eq!( + repo.oauth2_session() + .count(OAuth2SessionFilter::new().finished_only()) + .await + .unwrap(), + 3 + ); + + // We should have 1 active sessions + assert_eq!( + repo.oauth2_session() + .count(OAuth2SessionFilter::new().active_only()) + .await + .unwrap(), + 1 + ); } /// Test the [`OAuth2DeviceCodeGrantRepository`] implementation diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 663ecd352..9eb5f1a4d 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -206,6 +206,51 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { }) } + #[tracing::instrument( + name = "db.oauth2_session.finish_bulk", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: OAuth2SessionFilter<'_>, + ) -> Result { + let finished_at = clock.now(); + let (sql, arguments) = Query::update() + .table(OAuth2Sessions::Table) + .value(OAuth2Sessions::FinishedAt, finished_at) + .and_where_option(filter.user().map(|user| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.client().map(|client| { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)) + .eq(Uuid::from(client.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null() + } else { + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null() + } + })) + .and_where_option(filter.scope().map(|scope| { + let scope: Vec = scope.iter().map(|s| s.as_str().to_owned()).collect(); + Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope) + })) + .build_sqlx(PostgresQueryBuilder); + + let res = sqlx::query_with(&sql, arguments) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } + #[tracing::instrument( name = "db.oauth2_session.finish", skip_all, diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index 0ec57481f..5f9a5da40 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -259,6 +259,43 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { Ok(user_session) } + #[tracing::instrument( + name = "db.browser_session.finish_bulk", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: mas_storage::user::BrowserSessionFilter<'_>, + ) -> Result { + let finished_at = clock.now(); + let (sql, arguments) = sea_query::Query::update() + .table(UserSessions::Table) + .value(UserSessions::FinishedAt, finished_at) + .and_where_option(filter.user().map(|user| { + Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id)) + })) + .and_where_option(filter.state().map(|state| { + if state.is_active() { + Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null() + } else { + Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null() + } + })) + .build_sqlx(PostgresQueryBuilder); + + let res = sqlx::query_with(&sql, arguments) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(res.rows_affected().try_into().unwrap_or(usize::MAX)) + } + #[tracing::instrument( name = "db.browser_session.list", skip_all, @@ -560,7 +597,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip) FROM ( SELECT * - FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) AS t(user_session_id, last_active_at, last_active_ip) ) AS t WHERE user_sessions.user_session_id = t.user_session_id diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 917ebc854..789e9b4cb 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -534,19 +534,23 @@ async fn test_user_password_repo(pool: PgPool) { #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_user_session(pool: PgPool) { - const USERNAME: &str = "john"; - let mut repo = PgRepository::from_pool(&pool).await.unwrap(); let mut rng = ChaChaRng::seed_from_u64(42); let clock = MockClock::default(); - let user = repo + let alice = repo .user() - .add(&mut rng, &clock, USERNAME.to_owned()) + .add(&mut rng, &clock, "alice".to_owned()) .await .unwrap(); - let all = BrowserSessionFilter::default().for_user(&user); + let bob = repo + .user() + .add(&mut rng, &clock, "bob".to_owned()) + .await + .unwrap(); + + let all = BrowserSessionFilter::default(); let active = all.active_only(); let finished = all.finished_only(); @@ -556,10 +560,10 @@ async fn test_user_session(pool: PgPool) { let session = repo .browser_session() - .add(&mut rng, &clock, &user, None) + .add(&mut rng, &clock, &alice, None) .await .unwrap(); - assert_eq!(session.user.id, user.id); + assert_eq!(session.user.id, alice.id); assert!(session.finished_at.is_none()); assert_eq!(repo.browser_session().count(all).await.unwrap(), 1); @@ -584,7 +588,7 @@ async fn test_user_session(pool: PgPool) { .expect("user session not found"); assert_eq!(session_lookup.id, session.id); - assert_eq!(session_lookup.user.id, user.id); + assert_eq!(session_lookup.user.id, alice.id); assert!(session_lookup.finished_at.is_none()); // Finish the session @@ -616,9 +620,53 @@ async fn test_user_session(pool: PgPool) { .expect("user session not found"); assert_eq!(session_lookup.id, session.id); - assert_eq!(session_lookup.user.id, user.id); + assert_eq!(session_lookup.user.id, alice.id); // This time the session is finished assert!(session_lookup.finished_at.is_some()); + + // Create a bunch of other sessions + for _ in 0..5 { + for user in &[&alice, &bob] { + repo.browser_session() + .add(&mut rng, &clock, user, None) + .await + .unwrap(); + } + } + + let all_alice = BrowserSessionFilter::new().for_user(&alice); + let active_alice = BrowserSessionFilter::new().for_user(&alice).active_only(); + let all_bob = BrowserSessionFilter::new().for_user(&bob); + let active_bob = BrowserSessionFilter::new().for_user(&bob).active_only(); + assert_eq!(repo.browser_session().count(all).await.unwrap(), 11); + assert_eq!(repo.browser_session().count(active).await.unwrap(), 10); + assert_eq!(repo.browser_session().count(finished).await.unwrap(), 1); + assert_eq!(repo.browser_session().count(all_alice).await.unwrap(), 6); + assert_eq!(repo.browser_session().count(active_alice).await.unwrap(), 5); + assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5); + assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 5); + + // Finish all the sessions for alice + let affected = repo + .browser_session() + .finish_bulk(&clock, active_alice) + .await + .unwrap(); + assert_eq!(affected, 5); + assert_eq!(repo.browser_session().count(all_alice).await.unwrap(), 6); + assert_eq!(repo.browser_session().count(active_alice).await.unwrap(), 0); + assert_eq!(repo.browser_session().count(finished).await.unwrap(), 6); + + // Finish all the sessions for bob + let affected = repo + .browser_session() + .finish_bulk(&clock, active_bob) + .await + .unwrap(); + assert_eq!(affected, 5); + assert_eq!(repo.browser_session().count(all_bob).await.unwrap(), 5); + assert_eq!(repo.browser_session().count(active_bob).await.unwrap(), 0); + assert_eq!(repo.browser_session().count(finished).await.unwrap(), 11); } #[sqlx::test(migrator = "crate::MIGRATOR")] diff --git a/crates/storage/src/compat/session.rs b/crates/storage/src/compat/session.rs index 6a0b4ab54..227399ac3 100644 --- a/crates/storage/src/compat/session.rs +++ b/crates/storage/src/compat/session.rs @@ -209,6 +209,24 @@ pub trait CompatSessionRepository: Send + Sync { compat_session: CompatSession, ) -> Result; + /// Mark all the [`CompatSession`] matching the given filter as finished + /// + /// Returns the number of sessions affected + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: CompatSessionFilter<'_>, + ) -> Result; + /// List [`CompatSession`] with the given filter and pagination /// /// Returns a page of compat sessions, with the associated SSO logins if any @@ -289,6 +307,12 @@ repository_impl!(CompatSessionRepository: compat_session: CompatSession, ) -> Result; + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: CompatSessionFilter<'_>, + ) -> Result; + async fn list( &mut self, filter: CompatSessionFilter<'_>, diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 60a130e0c..46e969387 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -1,4 +1,4 @@ -// Copyright 2022, 2023 The Matrix.org Foundation C.I.C. +// Copyright 2022-2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -245,6 +245,24 @@ pub trait OAuth2SessionRepository: Send + Sync { async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; + /// Mark all the [`Session`] matching the given filter as finished + /// + /// Returns the number of sessions affected + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `filter`: The filter parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: OAuth2SessionFilter<'_>, + ) -> Result; + /// List [`Session`]s matching the given filter and pagination parameters /// /// # Parameters @@ -333,6 +351,12 @@ repository_impl!(OAuth2SessionRepository: async fn finish(&mut self, clock: &dyn Clock, session: Session) -> Result; + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: OAuth2SessionFilter<'_>, + ) -> Result; + async fn list( &mut self, filter: OAuth2SessionFilter<'_>, diff --git a/crates/storage/src/user/session.rs b/crates/storage/src/user/session.rs index fa1d1763e..a10c65b67 100644 --- a/crates/storage/src/user/session.rs +++ b/crates/storage/src/user/session.rs @@ -148,6 +148,24 @@ pub trait BrowserSessionRepository: Send + Sync { user_session: BrowserSession, ) -> Result; + /// Mark all the [`BrowserSession`] matching the given filter as finished + /// + /// Returns the number of sessions affected + /// + /// # Parameters + /// + /// * `clock`: The clock used to generate timestamps + /// * `filter`: The filter parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: BrowserSessionFilter<'_>, + ) -> Result; + /// List [`BrowserSession`] with the given filter and pagination /// /// # Parameters @@ -262,6 +280,12 @@ repository_impl!(BrowserSessionRepository: user_session: BrowserSession, ) -> Result; + async fn finish_bulk( + &mut self, + clock: &dyn Clock, + filter: BrowserSessionFilter<'_>, + ) -> Result; + async fn list( &mut self, filter: BrowserSessionFilter<'_>, diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index fe64b6c16..7fd711ff8 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -15,8 +15,10 @@ use anyhow::Context; use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; use mas_storage::{ + compat::CompatSessionFilter, job::{DeactivateUserJob, JobWithSpanContext, ReactivateUserJob}, - user::UserRepository, + oauth2::OAuth2SessionFilter, + user::{BrowserSessionFilter, UserRepository}, RepositoryAccess, }; use tracing::info; @@ -52,7 +54,33 @@ async fn deactivate_user( .await .context("Failed to lock user")?; - // TODO: delete the sessions & access tokens + // Kill all sessions for the user + let n = repo + .browser_session() + .finish_bulk( + &clock, + BrowserSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all browser sessions for user"); + + let n = repo + .oauth2_session() + .finish_bulk( + &clock, + OAuth2SessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all OAuth 2.0 sessions for user"); + + let n = repo + .compat_session() + .finish_bulk( + &clock, + CompatSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all compatibility sessions for user"); // Before calling back to the homeserver, commit the changes to the database, as // we want the user to be locked out as soon as possible