From ebfb01a65290754125d3f2b6bd9c43526133c95e Mon Sep 17 00:00:00 2001
From: Travis Ralston <travpc@gmail.com>
Date: Wed, 6 Jul 2022 07:50:34 +0200
Subject: [PATCH] Ensure appservice intents update their internal membership
 cache (#243)

* Ensure appservice intents update their internal membership cache

Fixes https://github.com/turt2live/matrix-bot-sdk/issues/113

* Fix appservice tests
---
 src/appservice/Appservice.ts      |  7 ++-
 test/appservice/AppserviceTest.ts | 95 +++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/src/appservice/Appservice.ts b/src/appservice/Appservice.ts
index 6138d785..2dae4ad9 100644
--- a/src/appservice/Appservice.ts
+++ b/src/appservice/Appservice.ts
@@ -608,9 +608,12 @@ export class Appservice extends EventEmitter {
         return event;
     }
 
-    private processMembershipEvent(event: any): void {
+    private async processMembershipEvent(event: any): Promise<void> {
         if (!event["content"]) return;
 
+        // Update the target intent's joined rooms (fixes transition errors with the cache, like join->kick->join)
+        await this.getIntentForUserId(event['state_key']).refreshJoinedRooms();
+
         const targetMembership = event["content"]["membership"];
         if (targetMembership === "join") {
             this.emit("room.join", event["room_id"], event);
@@ -830,7 +833,7 @@ export class Appservice extends EventEmitter {
                     this.emit("room.message", event["room_id"], event);
                 }
                 if (event['type'] === 'm.room.member' && this.isNamespacedUser(event['state_key'])) {
-                    this.processMembershipEvent(event);
+                    await this.processMembershipEvent(event);
                 }
                 if (event['type'] === 'm.room.tombstone' && event['state_key'] === '') {
                     this.emit("room.archived", event['room_id'], event);
diff --git a/test/appservice/AppserviceTest.ts b/test/appservice/AppserviceTest.ts
index fd59eafe..fbcb4bc1 100644
--- a/test/appservice/AppserviceTest.ts
+++ b/test/appservice/AppserviceTest.ts
@@ -1429,6 +1429,9 @@ describe('Appservice', () => {
             return null;
         };
 
+        const intent = appservice.getIntentForSuffix("test");
+        intent.refreshJoinedRooms = () => Promise.resolve([]);
+
         await appservice.begin();
 
         try {
@@ -1532,6 +1535,98 @@ describe('Appservice', () => {
         }
     });
 
+    it('should refresh membership information of intents when actions are performed against them', async () => {
+        const port = await getPort();
+        const hsToken = "s3cret_token";
+        const appservice = new Appservice({
+            port: port,
+            bindAddress: '',
+            homeserverName: 'example.org',
+            homeserverUrl: 'https://localhost',
+            registration: {
+                as_token: "",
+                hs_token: hsToken,
+                sender_localpart: "_bot_",
+                namespaces: {
+                    users: [{ exclusive: true, regex: "@_prefix_.*:.+" }],
+                    rooms: [],
+                    aliases: [],
+                },
+            },
+        });
+        appservice.botIntent.ensureRegistered = () => {
+            return null;
+        };
+
+        await appservice.begin();
+
+        try {
+            const intent = appservice.getIntentForSuffix("test");
+            const refreshSpy = simple.stub().callFn(() => Promise.resolve([]));
+            intent.refreshJoinedRooms = refreshSpy;
+
+            // polyfill the dummy user too
+            const intent2 = appservice.getIntentForSuffix("test___WRONGUSER");
+            intent2.refreshJoinedRooms = () => Promise.resolve([]);
+
+            const joinTxn = {
+                events: [
+                    {
+                        type: "m.room.member",
+                        room_id: "!AAA:example.org",
+                        content: { membership: "join" },
+                        state_key: "@_prefix_test:example.org",
+                        sender: "@_prefix_test:example.org",
+                    },
+                    {
+                        type: "m.room.member",
+                        room_id: "!AAA:example.org",
+                        content: { membership: "join" },
+                        state_key: "@_prefix_test___WRONGUSER:example.org",
+                        sender: "@_prefix_test:example.org",
+                    },
+                ],
+            };
+            const kickTxn = {
+                events: [
+                    {
+                        type: "m.room.member",
+                        room_id: "!AAA:example.org",
+                        content: { membership: "leave" },
+                        state_key: "@_prefix_test:example.org",
+                        sender: "@someone_else:example.org",
+                    },
+                    {
+                        type: "m.room.member",
+                        room_id: "!AAA:example.org",
+                        content: { membership: "leave" },
+                        state_key: "@_prefix_test___WRONGUSER:example.org",
+                        sender: "@someone_else:example.org",
+                    },
+                ],
+            };
+
+            // eslint-disable-next-line no-inner-declarations
+            async function doCall(route: string, opts: any = {}) {
+                const res = await requestPromise({
+                    uri: `http://localhost:${port}${route}`,
+                    method: "PUT",
+                    qs: { access_token: hsToken },
+                    ...opts,
+                });
+                expect(res).toMatchObject({});
+
+                expect(refreshSpy.callCount).toBe(1);
+                refreshSpy.callCount = 0;
+            }
+
+            await doCall("/transactions/1", { json: joinTxn });
+            await doCall("/_matrix/app/v1/transactions/2", { json: kickTxn });
+        } finally {
+            appservice.stop();
+        }
+    });
+
     it('should handle room upgrade events in transactions', async () => {
         const port = await getPort();
         const hsToken = "s3cret_token";