From 13c7e0ebda88d3676c8bc403fa4c5bd88a17023a Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 8 Dec 2023 15:21:07 +0100 Subject: [PATCH] Element-R: Refactor per-session key backup download (#3929) * initial commit * new interation test * more comments * fix test, quick refactor on request version * cleaning and logs * fix type * cleaning * remove delegate stuff * remove events and use timer mocks * fix import * ts ignore in tests * Quick cleaning * code review * Use Errors instead of Results * cleaning * review * remove forceCheck as not useful * bad naming * inline pauseLoop * mark as paused in finally * code review * post merge fix * rename KeyDownloadRateLimit * use same config in loop and pass along --- spec/integ/crypto/megolm-backup.spec.ts | 143 ++++- .../PerSessionKeyBackupDownloader.spec.ts | 598 ++++++++++++++++++ .../PerSessionKeyBackupDownloader.ts | 474 ++++++++++++++ src/rust-crypto/backup.ts | 54 +- src/rust-crypto/rust-crypto.ts | 131 +--- 5 files changed, 1289 insertions(+), 111 deletions(-) create mode 100644 spec/unit/rust-crypto/PerSessionKeyBackupDownloader.spec.ts create mode 100644 src/rust-crypto/PerSessionKeyBackupDownloader.ts diff --git a/spec/integ/crypto/megolm-backup.spec.ts b/spec/integ/crypto/megolm-backup.spec.ts index d12b7e5486e..d7f8644c8e0 100644 --- a/spec/integ/crypto/megolm-backup.spec.ts +++ b/spec/integ/crypto/megolm-backup.spec.ts @@ -18,7 +18,7 @@ import fetchMock from "fetch-mock-jest"; import "fake-indexeddb/auto"; import { IDBFactory } from "fake-indexeddb"; -import { createClient, CryptoEvent, ICreateClientOpts, MatrixClient, TypedEventEmitter } from "../../../src"; +import { createClient, CryptoEvent, ICreateClientOpts, IEvent, MatrixClient, TypedEventEmitter } from "../../../src"; import { SyncResponder } from "../../test-utils/SyncResponder"; import { E2EKeyReceiver } from "../../test-utils/E2EKeyReceiver"; import { E2EKeyResponder } from "../../test-utils/E2EKeyResponder"; @@ -34,6 +34,7 @@ import * as testData from "../../test-utils/test-data"; import { KeyBackupInfo } from "../../../src/crypto-api/keybackup"; import { IKeyBackup } from "../../../src/crypto/backup"; import { flushPromises } from "../../test-utils/flushPromises"; +import { defer, IDeferred } from "../../../src/utils"; const ROOM_ID = testData.TEST_ROOM_ID; @@ -888,6 +889,146 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe }); }); + describe("Backup Changed from other sessions", () => { + beforeEach(async () => { + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + + // ignore requests to send room key requests + fetchMock.put("express:/_matrix/client/v3/sendToDevice/m.room_key_request/:request_id", {}); + + aliceClient = await initTestClient(); + const aliceCrypto = aliceClient.getCrypto()!; + await aliceCrypto.storeSessionBackupPrivateKey( + Buffer.from(testData.BACKUP_DECRYPTION_KEY_BASE64, "base64"), + testData.SIGNED_BACKUP_DATA.version!, + ); + + // start after saving the private key + await aliceClient.startClient(); + + // tell Alice to trust the dummy device that signed the backup, and re-check the backup. + // XXX: should we automatically re-check after a device becomes verified? + await waitForDeviceList(); + await aliceClient.getCrypto()!.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + await aliceClient.getCrypto()!.checkKeyBackupAndEnable(); + }); + + // let aliceClient: MatrixClient; + + const SYNC_RESPONSE = { + next_batch: 1, + rooms: { join: { [ROOM_ID]: { timeline: { events: [testData.ENCRYPTED_EVENT] } } } }, + }; + + it("If current backup has changed, the manager should switch to the new one on UTD", async () => { + // ===== + // First ensure that the client checks for keys using the backup version 1 + /// ===== + + fetchMock.get( + "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", + (url, request) => { + // check that the version is correct + const version = new URLSearchParams(new URL(url).search).get("version"); + if (version == "1") { + return testData.CURVE25519_KEY_BACKUP_DATA; + } else { + return { + status: 403, + body: { + current_version: "1", + errcode: "M_WRONG_ROOM_KEYS_VERSION", + error: "Wrong backup version.", + }, + }; + } + }, + { overwriteRoutes: true }, + ); + + // Send Alice a message that she won't be able to decrypt, and check that she fetches the key from the backup. + syncResponder.sendOrQueueSyncResponse(SYNC_RESPONSE); + await syncPromise(aliceClient); + + const room = aliceClient.getRoom(ROOM_ID)!; + const event = room.getLiveTimeline().getEvents()[0]; + await advanceTimersUntil(awaitDecryption(event, { waitOnDecryptionFailure: true })); + + expect(event.getContent()).toEqual(testData.CLEAR_EVENT.content); + + // ===== + // Second suppose now that the backup has changed to version 2 + /// ===== + + const newBackup = { + ...testData.SIGNED_BACKUP_DATA, + version: "2", + }; + + fetchMock.get("path:/_matrix/client/v3/room_keys/version", newBackup, { overwriteRoutes: true }); + // suppose the new key is now known + const aliceCrypto = aliceClient.getCrypto()!; + await aliceCrypto.storeSessionBackupPrivateKey( + Buffer.from(testData.BACKUP_DECRYPTION_KEY_BASE64, "base64"), + newBackup.version, + ); + + // A check backup should happen at some point + await aliceCrypto.checkKeyBackupAndEnable(); + + const awaitHasQueriedNewBackup: IDeferred = defer(); + + fetchMock.get( + "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", + (url, request) => { + // check that the version is correct + const version = new URLSearchParams(new URL(url).search).get("version"); + if (version == newBackup.version) { + awaitHasQueriedNewBackup.resolve(); + return testData.CURVE25519_KEY_BACKUP_DATA; + } else { + // awaitHasQueriedOldBackup.resolve(); + return { + status: 403, + body: { + current_version: "2", + errcode: "M_WRONG_ROOM_KEYS_VERSION", + error: "Wrong backup version.", + }, + }; + } + }, + { overwriteRoutes: true }, + ); + + // Send Alice a message that she won't be able to decrypt, and check that she fetches the key from the new backup. + const newMessage: Partial = { + type: "m.room.encrypted", + room_id: "!room:id", + sender: "@alice:localhost", + content: { + algorithm: "m.megolm.v1.aes-sha2", + ciphertext: + "AwgAEpABKvf9FqPW52zeHfeVTn90a3jlBLlx7g6VDEkc2089RQUJoWpSJRiK13E83rN41wgGFJccyfoCr7ZDGJeuGYMGETTrgnLQhLs6JmyPf37JYkzxW8uS8rGUKEqTFQriKhibHVLvVacOlSIObUiKU/V3r176XuixqZF/4eyK9A22JNpInbgI10ZUT6LnApH9LR3FpZbE2zImf1uNPuvp7r0xQbW7CcJjqpH+qTPBD5zFdFnMkc2SnbXCsIOaX11Dm0krWfQz7iA26ZnI1nyZnyh7XPrCnJCRsuQH", + device_id: "WVMJGTSSVB", + sender_key: "E5RiY/YCIrHWaF4u416CqvblC6udK2jt9SJ/h1QeLS0", + session_id: "ybnW+LGdUhoS4fHm1DAEphukO3sZ1GCqZD7UQz7L+GA", + }, + event_id: "$event2", + origin_server_ts: 1507753887000, + }; + + const nextSyncResponse = { + next_batch: 2, + rooms: { join: { [ROOM_ID]: { timeline: { events: [newMessage] } } } }, + }; + syncResponder.sendOrQueueSyncResponse(nextSyncResponse); + await syncPromise(aliceClient); + + await awaitHasQueriedNewBackup.promise; + }); + }); + /** make sure that the client knows about the dummy device */ async function waitForDeviceList(): Promise { // Completing the initial sync will make the device list download outdated device lists (of which our own diff --git a/spec/unit/rust-crypto/PerSessionKeyBackupDownloader.spec.ts b/spec/unit/rust-crypto/PerSessionKeyBackupDownloader.spec.ts new file mode 100644 index 00000000000..8b1b0f75c2c --- /dev/null +++ b/spec/unit/rust-crypto/PerSessionKeyBackupDownloader.spec.ts @@ -0,0 +1,598 @@ +/* +Copyright 2023 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { Mocked, SpyInstance } from "jest-mock"; +import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-wasm"; +import { OlmMachine } from "@matrix-org/matrix-sdk-crypto-wasm"; +import fetchMock from "fetch-mock-jest"; + +import { PerSessionKeyBackupDownloader } from "../../../src/rust-crypto/PerSessionKeyBackupDownloader"; +import { logger } from "../../../src/logger"; +import { defer, IDeferred } from "../../../src/utils"; +import { RustBackupCryptoEventMap, RustBackupCryptoEvents, RustBackupManager } from "../../../src/rust-crypto/backup"; +import * as TestData from "../../test-utils/test-data"; +import { + ConnectionError, + CryptoEvent, + HttpApiEvent, + HttpApiEventHandlerMap, + IHttpOpts, + IMegolmSessionData, + MatrixHttpApi, + TypedEventEmitter, +} from "../../../src"; +import * as testData from "../../test-utils/test-data"; +import { BackupDecryptor } from "../../../src/common-crypto/CryptoBackend"; +import { KeyBackupSession } from "../../../src/crypto-api/keybackup"; + +describe("PerSessionKeyBackupDownloader", () => { + /** The downloader under test */ + let downloader: PerSessionKeyBackupDownloader; + + const mockCipherKey: Mocked = {} as unknown as Mocked; + + // matches the const in PerSessionKeyBackupDownloader + const BACKOFF_TIME = 5000; + + let mockEmitter: TypedEventEmitter; + let mockHttp: MatrixHttpApi; + let mockRustBackupManager: Mocked; + let mockOlmMachine: Mocked; + let mockBackupDecryptor: Mocked; + + let expectedSession: { [roomId: string]: { [sessionId: string]: IDeferred } }; + + function expectSessionImported(roomId: string, sessionId: string) { + const deferred = defer(); + if (!expectedSession[roomId]) { + expectedSession[roomId] = {}; + } + expectedSession[roomId][sessionId] = deferred; + return deferred.promise; + } + + function mockClearSession(sessionId: string): Mocked { + return { + session_id: sessionId, + } as unknown as Mocked; + } + + beforeEach(async () => { + mockEmitter = new TypedEventEmitter() as TypedEventEmitter; + + mockHttp = new MatrixHttpApi(new TypedEventEmitter(), { + baseUrl: "http://server/", + prefix: "", + onlyData: true, + }); + + mockBackupDecryptor = { + decryptSessions: jest.fn(), + } as unknown as Mocked; + + mockBackupDecryptor.decryptSessions.mockImplementation(async (ciphertexts) => { + const sessionId = Object.keys(ciphertexts)[0]; + return [mockClearSession(sessionId)]; + }); + + mockRustBackupManager = { + getActiveBackupVersion: jest.fn(), + requestKeyBackupVersion: jest.fn(), + importBackedUpRoomKeys: jest.fn(), + createBackupDecryptor: jest.fn().mockReturnValue(mockBackupDecryptor), + on: jest.fn().mockImplementation((event, listener) => { + mockEmitter.on(event, listener); + }), + off: jest.fn().mockImplementation((event, listener) => { + mockEmitter.off(event, listener); + }), + } as unknown as Mocked; + + mockOlmMachine = { + getBackupKeys: jest.fn(), + } as unknown as Mocked; + + downloader = new PerSessionKeyBackupDownloader(logger, mockOlmMachine, mockHttp, mockRustBackupManager); + + expectedSession = {}; + mockRustBackupManager.importBackedUpRoomKeys.mockImplementation(async (keys) => { + const roomId = keys[0].room_id; + const sessionId = keys[0].session_id; + const deferred = expectedSession[roomId] && expectedSession[roomId][sessionId]; + if (deferred) { + deferred.resolve(); + } + }); + + jest.useFakeTimers(); + }); + + afterEach(() => { + expectedSession = {}; + downloader.stop(); + fetchMock.mockReset(); + jest.useRealTimers(); + }); + + describe("Given valid backup available", () => { + beforeEach(async () => { + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA.version!); + mockOlmMachine.getBackupKeys.mockResolvedValue({ + backupVersion: TestData.SIGNED_BACKUP_DATA.version!, + decryptionKey: RustSdkCryptoJs.BackupDecryptionKey.fromBase64(TestData.BACKUP_DECRYPTION_KEY_BASE64), + } as unknown as RustSdkCryptoJs.BackupKeys); + + mockRustBackupManager.requestKeyBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA); + }); + + it("Should download and import a missing key from backup", async () => { + const awaitKeyImported = defer(); + const roomId = "!roomId"; + const sessionId = "sessionId"; + const expectAPICall = new Promise((resolve) => { + fetchMock.get(`path:/_matrix/client/v3/room_keys/keys/${roomId}/${sessionId}`, (url, request) => { + resolve(); + return TestData.CURVE25519_KEY_BACKUP_DATA; + }); + }); + mockRustBackupManager.importBackedUpRoomKeys.mockImplementation(async (keys) => { + awaitKeyImported.resolve(); + }); + mockBackupDecryptor.decryptSessions.mockResolvedValue([TestData.MEGOLM_SESSION_DATA]); + + downloader.onDecryptionKeyMissingError(roomId, sessionId); + + await expectAPICall; + await awaitKeyImported.promise; + expect(mockRustBackupManager.createBackupDecryptor).toHaveBeenCalledTimes(1); + }); + + it("Should not hammer the backup if the key is requested repeatedly", async () => { + const blockOnServerRequest = defer(); + + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/!roomId/:session_id`, async (url, request) => { + await blockOnServerRequest.promise; + return [mockCipherKey]; + }); + + const awaitKey2Imported = defer(); + + mockRustBackupManager.importBackedUpRoomKeys.mockImplementation(async (keys) => { + if (keys[0].session_id === "sessionId2") { + awaitKey2Imported.resolve(); + } + }); + + // @ts-ignore access to private function + const spy = jest.spyOn(downloader, "queryKeyBackup"); + + // Call 3 times for same key + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + + // Call again for a different key + downloader.onDecryptionKeyMissingError("!roomId", "sessionId2"); + + // Allow the first server request to complete + blockOnServerRequest.resolve(); + + await awaitKey2Imported.promise; + expect(spy).toHaveBeenCalledTimes(2); + }); + + it("should continue to next key if current not in backup", async () => { + fetchMock.get(`path:/_matrix/client/v3/room_keys/keys/!roomA/sessionA0`, { + status: 404, + body: { + errcode: "M_NOT_FOUND", + error: "No backup found", + }, + }); + fetchMock.get(`path:/_matrix/client/v3/room_keys/keys/!roomA/sessionA1`, mockCipherKey); + + // @ts-ignore access to private function + const spy: SpyInstance = jest.spyOn(downloader, "queryKeyBackup"); + + const expectImported = expectSessionImported("!roomA", "sessionA1"); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + await jest.runAllTimersAsync(); + expect(spy).toHaveBeenCalledTimes(1); + expect(spy).toHaveLastReturnedWith(Promise.resolve({ ok: false, error: "MISSING_DECRYPTION_KEY" })); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA1"); + await jest.runAllTimersAsync(); + expect(spy).toHaveBeenCalledTimes(2); + + await expectImported; + }); + + it("Should not query repeatedly for a key not in backup", async () => { + fetchMock.get(`path:/_matrix/client/v3/room_keys/keys/!roomA/sessionA0`, { + status: 404, + body: { + errcode: "M_NOT_FOUND", + error: "No backup found", + }, + }); + + // @ts-ignore access to private function + const spy: SpyInstance = jest.spyOn(downloader, "queryKeyBackup"); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + await jest.runAllTimersAsync(); + + expect(spy).toHaveBeenCalledTimes(1); + const returnedPromise = spy.mock.results[0].value; + await expect(returnedPromise).rejects.toThrow("Failed to get key from backup: MISSING_DECRYPTION_KEY"); + + // Should not query again for a key not in backup + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + await jest.runAllTimersAsync(); + + expect(spy).toHaveBeenCalledTimes(1); + + // advance time to retry + jest.advanceTimersByTime(BACKOFF_TIME + 10); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + await jest.runAllTimersAsync(); + + expect(spy).toHaveBeenCalledTimes(2); + await expect(spy.mock.results[1].value).rejects.toThrow( + "Failed to get key from backup: MISSING_DECRYPTION_KEY", + ); + }); + + it("Should stop properly", async () => { + // Simulate a call to stop while request is in flight + const blockOnServerRequest = defer(); + const requestRoomKeyCalled = defer(); + + // Mock the request to block + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, async (url, request) => { + requestRoomKeyCalled.resolve(); + await blockOnServerRequest.promise; + return mockCipherKey; + }); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + downloader.onDecryptionKeyMissingError("!roomA", "sessionA1"); + downloader.onDecryptionKeyMissingError("!roomA", "sessionA2"); + downloader.onDecryptionKeyMissingError("!roomA", "sessionA3"); + + await requestRoomKeyCalled.promise; + downloader.stop(); + + blockOnServerRequest.resolve(); + + // let the first request complete + await jest.runAllTimersAsync(); + + expect(mockRustBackupManager.importBackedUpRoomKeys).not.toHaveBeenCalled(); + expect( + fetchMock.calls(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`).length, + ).toStrictEqual(1); + }); + }); + + describe("Given no usable backup available", () => { + let getConfigSpy: SpyInstance; + + beforeEach(async () => { + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(null); + mockOlmMachine.getBackupKeys.mockResolvedValue(null); + + // @ts-ignore access to private function + getConfigSpy = jest.spyOn(downloader, "getOrCreateBackupConfiguration"); + }); + + it("Should not query server if no backup", async () => { + fetchMock.get("path:/_matrix/client/v3/room_keys/version", { + status: 404, + body: { errcode: "M_NOT_FOUND", error: "No current backup version." }, + }); + + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + + await jest.runAllTimersAsync(); + + expect(getConfigSpy).toHaveBeenCalledTimes(1); + expect(getConfigSpy).toHaveReturnedWith(Promise.resolve(null)); + }); + + it("Should not query server if backup not active", async () => { + // there is a backup + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + + // but it's not trusted + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(null); + + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + + await jest.runAllTimersAsync(); + + expect(getConfigSpy).toHaveBeenCalledTimes(1); + expect(getConfigSpy).toHaveReturnedWith(Promise.resolve(null)); + }); + + it("Should stop if backup key is not cached", async () => { + // there is a backup + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + // it is trusted + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA.version!); + // but the key is not cached + mockOlmMachine.getBackupKeys.mockResolvedValue(null); + + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + + await jest.runAllTimersAsync(); + + expect(getConfigSpy).toHaveBeenCalledTimes(1); + expect(getConfigSpy).toHaveReturnedWith(Promise.resolve(null)); + }); + + it("Should stop if backup key cached as wrong version", async () => { + // there is a backup + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + // it is trusted + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA.version!); + // but the cached key has the wrong version + mockOlmMachine.getBackupKeys.mockResolvedValue({ + backupVersion: "0", + decryptionKey: RustSdkCryptoJs.BackupDecryptionKey.fromBase64(TestData.BACKUP_DECRYPTION_KEY_BASE64), + } as unknown as RustSdkCryptoJs.BackupKeys); + + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + + await jest.runAllTimersAsync(); + + expect(getConfigSpy).toHaveBeenCalledTimes(1); + expect(getConfigSpy).toHaveReturnedWith(Promise.resolve(null)); + }); + + it("Should stop if backup key version does not match the active one", async () => { + // there is a backup + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + // The sdk is out of sync, the trusted version is the old one + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue("0"); + // key for old backup cached + mockOlmMachine.getBackupKeys.mockResolvedValue({ + backupVersion: "0", + decryptionKey: RustSdkCryptoJs.BackupDecryptionKey.fromBase64(TestData.BACKUP_DECRYPTION_KEY_BASE64), + } as unknown as RustSdkCryptoJs.BackupKeys); + + downloader.onDecryptionKeyMissingError("!roomId", "sessionId"); + + await jest.runAllTimersAsync(); + + expect(getConfigSpy).toHaveBeenCalledTimes(1); + expect(getConfigSpy).toHaveReturnedWith(Promise.resolve(null)); + }); + }); + + describe("Given Backup state update", () => { + it("After initial sync, when backup becomes trusted it should request keys for past requests", async () => { + // there is a backup + mockRustBackupManager.requestKeyBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA); + + // but at this point it's not trusted and we don't have the key + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(null); + mockOlmMachine.getBackupKeys.mockResolvedValue(null); + + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, mockCipherKey); + + const a0Imported = expectSessionImported("!roomA", "sessionA0"); + const a1Imported = expectSessionImported("!roomA", "sessionA1"); + const b1Imported = expectSessionImported("!roomB", "sessionB1"); + const c1Imported = expectSessionImported("!roomC", "sessionC1"); + + // During initial sync several keys are requested + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + downloader.onDecryptionKeyMissingError("!roomA", "sessionA1"); + downloader.onDecryptionKeyMissingError("!roomB", "sessionB1"); + downloader.onDecryptionKeyMissingError("!roomC", "sessionC1"); + await jest.runAllTimersAsync(); + + // @ts-ignore access to private property + expect(downloader.hasConfigurationProblem).toEqual(true); + + // Now the backup becomes trusted + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA.version!); + // And we have the key in cache + mockOlmMachine.getBackupKeys.mockResolvedValue({ + backupVersion: TestData.SIGNED_BACKUP_DATA.version!, + decryptionKey: RustSdkCryptoJs.BackupDecryptionKey.fromBase64(TestData.BACKUP_DECRYPTION_KEY_BASE64), + } as unknown as RustSdkCryptoJs.BackupKeys); + + // In that case the sdk would fire a backup status update + mockEmitter.emit(CryptoEvent.KeyBackupStatus, true); + + await jest.runAllTimersAsync(); + + await a0Imported; + await a1Imported; + await b1Imported; + await c1Imported; + }); + }); + + describe("Error cases", () => { + beforeEach(async () => { + // there is a backup + mockRustBackupManager.requestKeyBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA); + // It's trusted + mockRustBackupManager.getActiveBackupVersion.mockResolvedValue(TestData.SIGNED_BACKUP_DATA.version!); + // And we have the key in cache + mockOlmMachine.getBackupKeys.mockResolvedValue({ + backupVersion: TestData.SIGNED_BACKUP_DATA.version!, + decryptionKey: RustSdkCryptoJs.BackupDecryptionKey.fromBase64(TestData.BACKUP_DECRYPTION_KEY_BASE64), + } as unknown as RustSdkCryptoJs.BackupKeys); + }); + + it("Should wait on rate limit error", async () => { + // simulate rate limit error + fetchMock.get( + `express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, + { + status: 429, + body: { + errcode: "M_LIMIT_EXCEEDED", + error: "Too many requests", + retry_after_ms: 5000, + }, + }, + { overwriteRoutes: true }, + ); + + const keyImported = expectSessionImported("!roomA", "sessionA0"); + + // @ts-ignore + const originalImplementation = downloader.queryKeyBackup.bind(downloader); + + // @ts-ignore access to private function + const keyQuerySpy: SpyInstance = jest.spyOn(downloader, "queryKeyBackup"); + const rateDeferred = defer(); + + keyQuerySpy.mockImplementation( + // @ts-ignore + async (targetRoomId: string, targetSessionId: string, configuration: any) => { + try { + return await originalImplementation(targetRoomId, targetSessionId, configuration); + } catch (err: any) { + if (err.name === "KeyDownloadRateLimitError") { + rateDeferred.resolve(); + } + throw err; + } + }, + ); + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + + await rateDeferred.promise; + expect(keyQuerySpy).toHaveBeenCalledTimes(1); + await expect(keyQuerySpy.mock.results[0].value).rejects.toThrow( + "Failed to get key from backup: rate limited", + ); + + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, mockCipherKey, { + overwriteRoutes: true, + }); + + // Advance less than the retry_after_ms + jest.advanceTimersByTime(100); + // let any pending callbacks in PromiseJobs run + await Promise.resolve(); + // no additional call should have been made + expect(keyQuerySpy).toHaveBeenCalledTimes(1); + + // The loop should resume after the retry_after_ms + jest.advanceTimersByTime(5000); + // let any pending callbacks in PromiseJobs run + await Promise.resolve(); + + await keyImported; + expect(keyQuerySpy).toHaveBeenCalledTimes(2); + }); + + it("After a network error the same key is retried", async () => { + // simulate connectivity error + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, () => { + throw new ConnectionError("fetch failed", new Error("fetch failed")); + }); + + // @ts-ignore + const originalImplementation = downloader.queryKeyBackup.bind(downloader); + + // @ts-ignore + const keyQuerySpy: SpyInstance = jest.spyOn(downloader, "queryKeyBackup"); + const errorDeferred = defer(); + + keyQuerySpy.mockImplementation( + // @ts-ignore + async (targetRoomId: string, targetSessionId: string, configuration: any) => { + try { + return await originalImplementation(targetRoomId, targetSessionId, configuration); + } catch (err: any) { + if (err.name === "KeyDownloadError") { + errorDeferred.resolve(); + } + throw err; + } + }, + ); + const keyImported = expectSessionImported("!roomA", "sessionA0"); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + await errorDeferred.promise; + await Promise.resolve(); + + await expect(keyQuerySpy.mock.results[0].value).rejects.toThrow( + "Failed to get key from backup: NETWORK_ERROR", + ); + + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, mockCipherKey, { + overwriteRoutes: true, + }); + + // Advance less than the retry_after_ms + jest.advanceTimersByTime(100); + // let any pending callbacks in PromiseJobs run + await Promise.resolve(); + // no additional call should have been made + expect(keyQuerySpy).toHaveBeenCalledTimes(1); + + // The loop should resume after the retry_after_ms + jest.advanceTimersByTime(BACKOFF_TIME + 100); + await Promise.resolve(); + + await keyImported; + }); + + it("On Unknown error on import skip the key and continue", async () => { + const keyImported = defer(); + mockRustBackupManager.importBackedUpRoomKeys + .mockImplementationOnce(async () => { + throw new Error("Didn't work"); + }) + .mockImplementationOnce(async (sessions) => { + const roomId = sessions[0].room_id; + const sessionId = sessions[0].session_id; + if (roomId === "!roomA" && sessionId === "sessionA1") { + keyImported.resolve(); + } + return; + }); + + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/:roomId/:sessionId`, mockCipherKey, { + overwriteRoutes: true, + }); + + // @ts-ignore access to private function + const keyQuerySpy: SpyInstance = jest.spyOn(downloader, "queryKeyBackup"); + + downloader.onDecryptionKeyMissingError("!roomA", "sessionA0"); + downloader.onDecryptionKeyMissingError("!roomA", "sessionA1"); + await jest.runAllTimersAsync(); + + await keyImported.promise; + + expect(keyQuerySpy).toHaveBeenCalledTimes(2); + expect(mockRustBackupManager.importBackedUpRoomKeys).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/src/rust-crypto/PerSessionKeyBackupDownloader.ts b/src/rust-crypto/PerSessionKeyBackupDownloader.ts new file mode 100644 index 00000000000..c8283c65488 --- /dev/null +++ b/src/rust-crypto/PerSessionKeyBackupDownloader.ts @@ -0,0 +1,474 @@ +/* +Copyright 2023 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-wasm"; +import { OlmMachine } from "@matrix-org/matrix-sdk-crypto-wasm"; + +import { Curve25519AuthData, KeyBackupSession } from "../crypto-api/keybackup"; +import { Logger } from "../logger"; +import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api"; +import { RustBackupManager } from "./backup"; +import { CryptoEvent } from "../matrix"; +import { encodeUri, sleep } from "../utils"; +import { BackupDecryptor } from "../common-crypto/CryptoBackend"; + +// The minimum time to wait between two retries in case of errors. To avoid hammering the server. +const KEY_BACKUP_BACKOFF = 5000; // ms + +/** + * Enumerates the different kind of errors that can occurs when downloading and importing a key from backup. + */ +enum KeyDownloadErrorCode { + /** The requested key is not in the backup. */ + MISSING_DECRYPTION_KEY = "MISSING_DECRYPTION_KEY", + /** A network error occurred while trying to download the key from backup. */ + NETWORK_ERROR = "NETWORK_ERROR", + /** The loop has been stopped. */ + STOPPED = "STOPPED", +} + +class KeyDownloadError extends Error { + public constructor(public readonly code: KeyDownloadErrorCode) { + super(`Failed to get key from backup: ${code}`); + this.name = "KeyDownloadError"; + } +} + +class KeyDownloadRateLimitError extends Error { + public constructor(public readonly retryMillis: number) { + super(`Failed to get key from backup: rate limited`); + this.name = "KeyDownloadRateLimitError"; + } +} + +/** Details of a megolm session whose key we are trying to fetch. */ +type SessionInfo = { roomId: string; megolmSessionId: string }; + +/** Holds the current backup decryptor and version that should be used. */ +type Configuration = { + backupVersion: string; + decryptor: BackupDecryptor; +}; + +/** + * Used when an 'unable to decrypt' error occurs. It attempts to download the key from the backup. + * + * The current backup API lacks pagination, which can lead to lengthy key retrieval times for large histories (several 10s of minutes). + * To mitigate this, keys are downloaded on demand as decryption errors occurs. + * While this approach may result in numerous requests, it improves user experience by reducing wait times for message decryption. + * + * The PerSessionKeyBackupDownloader is resistant to backup configuration changes: it will automatically resume querying when + * the backup is configured correctly. + */ +export class PerSessionKeyBackupDownloader { + private stopped = false; + + /** The version and decryption key to use with current backup if all set up correctly */ + private configuration: Configuration | null = null; + + /** We remember when a session was requested and not found in backup to avoid query again too soon. + * Map of session_id to timestamp */ + private sessionLastCheckAttemptedTime: Map = new Map(); + + /** The logger to use */ + private readonly logger: Logger; + + /** Whether the download loop is running. */ + private downloadLoopRunning = false; + + /** The list of requests that are queued. */ + private queuedRequests: SessionInfo[] = []; + + /** Remembers if we have a configuration problem. */ + private hasConfigurationProblem = false; + + /** The current server backup version check promise. To avoid doing a server call if one is in flight. */ + private currentBackupVersionCheck: Promise | null = null; + + /** + * Creates a new instance of PerSessionKeyBackupDownloader. + * + * @param backupManager - The backup manager to use. + * @param olmMachine - The olm machine to use. + * @param http - The http instance to use. + * @param logger - The logger to use. + */ + public constructor( + logger: Logger, + private readonly olmMachine: OlmMachine, + private readonly http: MatrixHttpApi, + private readonly backupManager: RustBackupManager, + ) { + this.logger = logger.getChild("[PerSessionKeyBackupDownloader]"); + + backupManager.on(CryptoEvent.KeyBackupStatus, this.onBackupStatusChanged); + backupManager.on(CryptoEvent.KeyBackupFailed, this.onBackupStatusChanged); + backupManager.on(CryptoEvent.KeyBackupDecryptionKeyCached, this.onBackupStatusChanged); + } + + /** + * Called when a MissingRoomKey or UnknownMessageIndex decryption error is encountered. + * + * This will try to download the key from the backup if there is a trusted active backup. + * In case of success the key will be imported and the onRoomKeysUpdated callback will be called + * internally by the rust-sdk and decryption will be retried. + * + * @param roomId - The room ID of the room where the error occurred. + * @param megolmSessionId - The megolm session ID that is missing. + */ + public onDecryptionKeyMissingError(roomId: string, megolmSessionId: string): void { + // Several messages encrypted with the same session may be decrypted at the same time, + // so we need to be resistant and not query several time the same session. + if (this.isAlreadyInQueue(roomId, megolmSessionId)) { + // There is already a request queued for this session, no need to queue another one. + this.logger.trace(`Not checking key backup for session ${megolmSessionId} as it is already queued`); + return; + } + + if (this.wasRequestedRecently(megolmSessionId)) { + // We already tried to download this session recently and it was not in backup, no need to try again. + this.logger.trace( + `Not checking key backup for session ${megolmSessionId} as it was already requested recently`, + ); + return; + } + + // We always add the request to the queue, even if we have a configuration problem (can't access backup). + // This is to make sure that if the configuration problem is resolved, we will try to download the key. + // This will happen after an initial sync, at this point the backup will not yet be trusted and the decryption + // key will not be available, but it will be just after the verification. + // We don't need to persist it because currently on refresh the sdk will retry to decrypt the messages in error. + this.queuedRequests.push({ roomId, megolmSessionId }); + + // Start the download loop if it's not already running. + this.downloadKeysLoop(); + } + + public stop(): void { + this.stopped = true; + this.backupManager.off(CryptoEvent.KeyBackupStatus, this.onBackupStatusChanged); + this.backupManager.off(CryptoEvent.KeyBackupFailed, this.onBackupStatusChanged); + this.backupManager.off(CryptoEvent.KeyBackupDecryptionKeyCached, this.onBackupStatusChanged); + } + + /** + * Called when the backup status changes (CryptoEvents) + * This will trigger a check of the backup configuration. + */ + private onBackupStatusChanged = (): void => { + // we want to force check configuration, so we clear the current one. + this.hasConfigurationProblem = false; + this.configuration = null; + this.getOrCreateBackupConfiguration().then((configuration) => { + if (configuration) { + // restart the download loop if it was stopped + this.downloadKeysLoop(); + } + }); + }; + + /** Returns true if the megolm session is already queued for download. */ + private isAlreadyInQueue(roomId: string, megolmSessionId: string): boolean { + return this.queuedRequests.some((info) => { + return info.roomId == roomId && info.megolmSessionId == megolmSessionId; + }); + } + + /** + * Marks the session as not found in backup, to avoid retrying to soon for a key not in backup + * + * @param megolmSessionId - The megolm session ID that is missing. + */ + private markAsNotFoundInBackup(megolmSessionId: string): void { + const now = Date.now(); + this.sessionLastCheckAttemptedTime.set(megolmSessionId, now); + // if too big make some cleaning to keep under control + if (this.sessionLastCheckAttemptedTime.size > 100) { + this.sessionLastCheckAttemptedTime = new Map( + Array.from(this.sessionLastCheckAttemptedTime).filter((sid, ts) => { + return Math.max(now - ts, 0) < KEY_BACKUP_BACKOFF; + }), + ); + } + } + + /** Returns true if the session was requested recently. */ + private wasRequestedRecently(megolmSessionId: string): boolean { + const lastCheck = this.sessionLastCheckAttemptedTime.get(megolmSessionId); + if (!lastCheck) return false; + return Math.max(Date.now() - lastCheck, 0) < KEY_BACKUP_BACKOFF; + } + + private async getBackupDecryptionKey(): Promise { + try { + return await this.olmMachine.getBackupKeys(); + } catch (e) { + return null; + } + } + + /** + * Requests a key from the server side backup. + * + * @param version - The backup version to use. + * @param roomId - The room ID of the room where the error occurred. + * @param sessionId - The megolm session ID that is missing. + */ + private async requestRoomKeyFromBackup( + version: string, + roomId: string, + sessionId: string, + ): Promise { + const path = encodeUri("/room_keys/keys/$roomId/$sessionId", { + $roomId: roomId, + $sessionId: sessionId, + }); + + return await this.http.authedRequest(Method.Get, path, { version }, undefined, { + prefix: ClientPrefix.V3, + }); + } + + private async downloadKeysLoop(): Promise { + if (this.downloadLoopRunning) return; + + // If we have a configuration problem, we don't want to try to download. + // If any configuration change is detected, we will retry and restart the loop. + if (this.hasConfigurationProblem) return; + + this.downloadLoopRunning = true; + + try { + while (this.queuedRequests.length > 0) { + // we just peek the first one without removing it, so if a new request for same key comes in while we're + // processing this one, it won't queue another request. + const request = this.queuedRequests[0]; + try { + // The backup could have changed between the time we queued the request and now, so we need to check + const configuration = await this.getOrCreateBackupConfiguration(); + if (!configuration) { + // Backup is not configured correctly, so stop the loop. + this.downloadLoopRunning = false; + return; + } + + const result = await this.queryKeyBackup(request.roomId, request.megolmSessionId, configuration); + + if (this.stopped) { + return; + } + // We got the encrypted key from backup, let's try to decrypt and import it. + try { + await this.decryptAndImport(request, result, configuration); + } catch (e) { + this.logger.error( + `Error while decrypting and importing key backup for session ${request.megolmSessionId}`, + e, + ); + } + // now remove the request from the queue as we've processed it. + this.queuedRequests.shift(); + } catch (err) { + if (err instanceof KeyDownloadError) { + switch (err.code) { + case KeyDownloadErrorCode.MISSING_DECRYPTION_KEY: + this.markAsNotFoundInBackup(request.megolmSessionId); + // continue for next one + this.queuedRequests.shift(); + break; + case KeyDownloadErrorCode.NETWORK_ERROR: + // We don't want to hammer if there is a problem, so wait a bit. + await sleep(KEY_BACKUP_BACKOFF); + break; + case KeyDownloadErrorCode.STOPPED: + // If the downloader was stopped, we don't want to retry. + this.downloadLoopRunning = false; + return; + } + } else if (err instanceof KeyDownloadRateLimitError) { + // we want to retry after the backoff time + await sleep(err.retryMillis); + } + } + } + } finally { + // all pending request have been processed, we can stop the loop. + this.downloadLoopRunning = false; + } + } + + /** + * Query the backup for a key. + * + * @param targetRoomId - ID of the room that the session is used in. + * @param targetSessionId - ID of the session for which to check backup. + * @param configuration - The backup configuration to use. + */ + private async queryKeyBackup( + targetRoomId: string, + targetSessionId: string, + configuration: Configuration, + ): Promise { + this.logger.debug(`Checking key backup for session ${targetSessionId}`); + if (this.stopped) throw new KeyDownloadError(KeyDownloadErrorCode.STOPPED); + try { + const res = await this.requestRoomKeyFromBackup(configuration.backupVersion, targetRoomId, targetSessionId); + this.logger.debug(`Got key from backup for sessionId:${targetSessionId}`); + return res; + } catch (e) { + if (this.stopped) throw new KeyDownloadError(KeyDownloadErrorCode.STOPPED); + + this.logger.info(`No luck requesting key backup for session ${targetSessionId}: ${e}`); + if (e instanceof MatrixError) { + const errCode = e.data.errcode; + if (errCode == "M_NOT_FOUND") { + // Unfortunately the spec doesn't give us a way to differentiate between a missing key and a wrong version. + // Synapse will return: + // - "error": "Unknown backup version" if the version is wrong. + // - "error": "No room_keys found" if the key is missing. + // It's useful to know if the key is missing or if the version is wrong. + // As it's not spec'ed, we fall back on considering the key is not in backup. + // Notice that this request will be lost if instead the backup got out of sync (updated from other session). + throw new KeyDownloadError(KeyDownloadErrorCode.MISSING_DECRYPTION_KEY); + } + if (errCode == "M_LIMIT_EXCEEDED") { + const waitTime = e.data.retry_after_ms; + if (waitTime > 0) { + this.logger.info(`Rate limited by server, waiting ${waitTime}ms`); + throw new KeyDownloadRateLimitError(waitTime); + } else { + // apply the default backoff time + throw new KeyDownloadRateLimitError(KEY_BACKUP_BACKOFF); + } + } + } + throw new KeyDownloadError(KeyDownloadErrorCode.NETWORK_ERROR); + } + } + + private async decryptAndImport( + sessionInfo: SessionInfo, + data: KeyBackupSession, + configuration: Configuration, + ): Promise { + const sessionsToImport: Record = { [sessionInfo.megolmSessionId]: data }; + + const keys = await configuration!.decryptor.decryptSessions(sessionsToImport); + for (const k of keys) { + k.room_id = sessionInfo.roomId; + } + await this.backupManager.importBackedUpRoomKeys(keys); + } + + /** + * Gets the current backup configuration or create one if it doesn't exist. + * + * When a valid configuration is found it is cached and returned for subsequent calls. + * Otherwise, if a check is forced or a check has not yet been done, a new check is done. + * + * @returns The backup configuration to use or null if there is a configuration problem. + */ + private async getOrCreateBackupConfiguration(): Promise { + if (this.configuration) { + return this.configuration; + } + + // We already tried to check the configuration and it failed. + // We don't want to try again immediately, we will retry if a configuration change is detected. + if (this.hasConfigurationProblem) { + return null; + } + + // This method can be called rapidly by several emitted CryptoEvent, so we need to make sure that we don't + // query the server several times. + if (this.currentBackupVersionCheck != null) { + this.logger.debug(`Already checking server version, use current promise`); + return await this.currentBackupVersionCheck; + } + + this.currentBackupVersionCheck = this.internalCheckFromServer(); + try { + return await this.currentBackupVersionCheck; + } finally { + this.currentBackupVersionCheck = null; + } + } + + private async internalCheckFromServer(): Promise { + let currentServerVersion = null; + try { + currentServerVersion = await this.backupManager.requestKeyBackupVersion(); + } catch (e) { + this.logger.debug(`Backup: error while checking server version: ${e}`); + this.hasConfigurationProblem = true; + return null; + } + this.logger.debug(`Got current backup version from server: ${currentServerVersion?.version}`); + + if (currentServerVersion?.algorithm != "m.megolm_backup.v1.curve25519-aes-sha2") { + this.logger.info(`Unsupported algorithm ${currentServerVersion?.algorithm}`); + this.hasConfigurationProblem = true; + return null; + } + + if (!currentServerVersion?.version) { + this.logger.info(`No current key backup`); + this.hasConfigurationProblem = true; + return null; + } + + const activeVersion = await this.backupManager.getActiveBackupVersion(); + if (activeVersion == null || currentServerVersion.version != activeVersion) { + // Either the current backup version on server side is not trusted, or it is out of sync with the active version on the client side. + this.logger.info( + `The current backup version on the server (${currentServerVersion.version}) is not trusted. Version we are currently backing up to: ${activeVersion}`, + ); + this.hasConfigurationProblem = true; + return null; + } + + const authData = currentServerVersion.auth_data as Curve25519AuthData; + + const backupKeys = await this.getBackupDecryptionKey(); + if (!backupKeys?.decryptionKey) { + this.logger.debug(`Not checking key backup for session (no decryption key)`); + this.hasConfigurationProblem = true; + return null; + } + + if (activeVersion != backupKeys.backupVersion) { + this.logger.debug( + `Version for which we have a decryption key (${backupKeys.backupVersion}) doesn't match the version we are backing up to (${activeVersion})`, + ); + this.hasConfigurationProblem = true; + return null; + } + + if (authData.public_key != backupKeys.decryptionKey.megolmV1PublicKey.publicKeyBase64) { + this.logger.debug(`getBackupDecryptor key mismatch error`); + this.hasConfigurationProblem = true; + return null; + } + + const backupDecryptor = this.backupManager.createBackupDecryptor(backupKeys.decryptionKey); + this.hasConfigurationProblem = false; + this.configuration = { + decryptor: backupDecryptor, + backupVersion: activeVersion, + }; + return this.configuration; + } +} diff --git a/src/rust-crypto/backup.ts b/src/rust-crypto/backup.ts index f8c23e19dfa..1b3f8891040 100644 --- a/src/rust-crypto/backup.ts +++ b/src/rust-crypto/backup.ts @@ -34,6 +34,7 @@ import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor"; import { sleep } from "../utils"; import { BackupDecryptor } from "../common-crypto/CryptoBackend"; import { IEncryptedPayload } from "../crypto/aes"; +import { ImportRoomKeyProgressData, ImportRoomKeysOpts } from "../crypto-api"; /** Authentification of the backup info, depends on algorithm */ type AuthData = KeyBackupInfo["auth_data"]; @@ -173,6 +174,49 @@ export class RustBackupManager extends TypedEventEmitter { + const jsonKeys = JSON.stringify(keys); + await this.olmMachine.importExportedRoomKeys(jsonKeys, (progress: BigInt, total: BigInt): void => { + const importOpt: ImportRoomKeyProgressData = { + total: Number(total), + successes: Number(progress), + stage: "load_keys", + failures: 0, + }; + opts?.progressCallback?.(importOpt); + }); + } + + /** + * Implementation of {@link CryptoBackend#importBackedUpRoomKeys}. + */ + public async importBackedUpRoomKeys(keys: IMegolmSessionData[], opts?: ImportRoomKeysOpts): Promise { + const keysByRoom: Map> = new Map(); + for (const key of keys) { + const roomId = new RustSdkCryptoJs.RoomId(key.room_id); + if (!keysByRoom.has(roomId)) { + keysByRoom.set(roomId, new Map()); + } + keysByRoom.get(roomId)!.set(key.session_id, key); + } + await this.olmMachine.importBackedUpRoomKeys(keysByRoom, (progress: BigInt, total: BigInt): void => { + const importOpt: ImportRoomKeyProgressData = { + total: Number(total), + successes: Number(progress), + stage: "load_keys", + failures: 0, + }; + opts?.progressCallback?.(importOpt); + }); + } + private keyBackupCheckInProgress: Promise | null = null; /** Helper for `checkKeyBackup` */ @@ -348,7 +392,7 @@ export class RustBackupManager extends TypedEventEmitter { + public async requestKeyBackupVersion(): Promise { try { return await this.http.authedRequest( Method.Get, @@ -440,6 +484,14 @@ export class RustBackupManager extends TypedEventEmitter = {}; // When did we last try to check the server for a given session id? + private readonly perSessionBackupDownloader: PerSessionKeyBackupDownloader; private readonly reemitter = new TypedReEmitter(this); @@ -143,9 +140,18 @@ export class RustCrypto extends TypedEventEmitter KEY_BACKUP_CHECK_RATE_LIMIT) { - this.sessionLastCheckAttemptedTime[targetSessionId!] = now; - this.queryKeyBackup(targetRoomId, targetSessionId).catch((e) => { - this.logger.error(`Unhandled error while checking key backup for session ${targetSessionId}`, e); - }); - } else { - const lastCheckStr = new Date(lastCheck).toISOString(); - this.logger.debug( - `Not checking key backup for session ${targetSessionId} (last checked at ${lastCheckStr})`, - ); - } - } - - /** - * Helper for {@link RustCrypto#startQueryKeyBackupRateLimited}. - * - * Requests the backup and imports it. Doesn't do any rate-limiting. - * - * @param targetRoomId - ID of the room that the session is used in. - * @param targetSessionId - ID of the session for which to check backup. - */ - private async queryKeyBackup(targetRoomId: string, targetSessionId: string): Promise { - const backupKeys: RustSdkCryptoJs.BackupKeys = await this.olmMachine.getBackupKeys(); - if (!backupKeys.decryptionKey) { - this.logger.debug(`Not checking key backup for session ${targetSessionId} (no decryption key)`); - return; - } - - this.logger.debug(`Checking key backup for session ${targetSessionId}`); - - const version = backupKeys.backupVersion; - const path = encodeUri("/room_keys/keys/$roomId/$sessionId", { - $roomId: targetRoomId, - $sessionId: targetSessionId, - }); - - let res: KeyBackupSession; - try { - res = await this.http.authedRequest(Method.Get, path, { version }, undefined, { - prefix: ClientPrefix.V3, - }); - } catch (e) { - this.logger.info(`No luck requesting key backup for session ${targetSessionId}: ${e}`); - return; - } - - if (this.stopped) return; - - const backupDecryptor = new RustBackupDecryptor(backupKeys.decryptionKey); - const sessionsToImport: Record = { [targetSessionId]: res }; - const keys = await backupDecryptor.decryptSessions(sessionsToImport); - for (const k of keys) { - k.room_id = targetRoomId; - } - await this.importBackedUpRoomKeys(keys); - } - /** * Return the OlmMachine only if {@link RustCrypto#stop} has not been called. * @@ -268,6 +205,7 @@ export class RustCrypto extends TypedEventEmitter { - const jsonKeys = JSON.stringify(keys); - await this.olmMachine.importExportedRoomKeys(jsonKeys, (progress: BigInt, total: BigInt): void => { - const importOpt: ImportRoomKeyProgressData = { - total: Number(total), - successes: Number(progress), - stage: "load_keys", - failures: 0, - }; - opts?.progressCallback?.(importOpt); - }); + return await this.backupManager.importRoomKeys(keys, opts); } /** @@ -1261,30 +1190,14 @@ export class RustCrypto extends TypedEventEmitter { - const keysByRoom: Map> = new Map(); - for (const key of keys) { - const roomId = new RustSdkCryptoJs.RoomId(key.room_id); - if (!keysByRoom.has(roomId)) { - keysByRoom.set(roomId, new Map()); - } - keysByRoom.get(roomId)!.set(key.session_id, key); - } - await this.olmMachine.importBackedUpRoomKeys(keysByRoom, (progress: BigInt, total: BigInt): void => { - const importOpt: ImportRoomKeyProgressData = { - total: Number(total), - successes: Number(progress), - stage: "load_keys", - failures: 0, - }; - opts?.progressCallback?.(importOpt); - }); + return await this.backupManager.importBackedUpRoomKeys(keys, opts); } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1683,7 +1596,7 @@ class EventDecryptor { public constructor( private readonly logger: Logger, private readonly olmMachine: RustSdkCryptoJs.OlmMachine, - private readonly crypto: RustCrypto, + private readonly perSessionBackupDownloader: PerSessionKeyBackupDownloader, ) {} public async attemptEventDecryption(event: MatrixEvent): Promise { @@ -1724,7 +1637,7 @@ class EventDecryptor { session: content.sender_key + "|" + content.session_id, }, ); - this.crypto.startQueryKeyBackupRateLimited( + this.perSessionBackupDownloader.onDecryptionKeyMissingError( event.getRoomId()!, event.getWireContent().session_id!, ); @@ -1738,7 +1651,7 @@ class EventDecryptor { session: content.sender_key + "|" + content.session_id, }, ); - this.crypto.startQueryKeyBackupRateLimited( + this.perSessionBackupDownloader.onDecryptionKeyMissingError( event.getRoomId()!, event.getWireContent().session_id!, );