@@ -105,6 +105,9 @@ async def filter_events_for_client(
105
105
The filtered events. The `unsigned` data is annotated with the membership state
106
106
of `user_id` at each event.
107
107
"""
108
+ if not events :
109
+ return []
110
+
108
111
# Filter out events that have been soft failed so that we don't relay them
109
112
# to clients.
110
113
events_before_filtering = events
@@ -117,13 +120,38 @@ async def filter_events_for_client(
117
120
[event .event_id for event in events ],
118
121
)
119
122
120
- types = (_HISTORY_VIS_KEY , (EventTypes .Member , user_id ))
123
+ types = (
124
+ _HISTORY_VIS_KEY ,
125
+ (EventTypes .Member , user_id ),
126
+ )
127
+
128
+ room_id = events [0 ].room_id
129
+ assert all (event .room_id == room_id for event in events )
130
+
131
+ visibilities : Dict [str , str ] = {}
132
+ memberships : Dict [str , Optional [EventBase ]] = {}
133
+ events_to_fetch = {e .event_id for e in events if not e .internal_metadata .outlier }
134
+ if not is_peeking :
135
+ fetched_visibilities = await storage .main .get_visibility_for_events (
136
+ room_id , [e for e in events if not e .internal_metadata .outlier ]
137
+ )
138
+ for event_id , visibility in fetched_visibilities .items ():
139
+ if visibility in (
140
+ HistoryVisibility .SHARED ,
141
+ HistoryVisibility .WORLD_READABLE ,
142
+ ):
143
+ events_to_fetch .discard (event_id )
144
+ visibilities [event_id ] = visibility
121
145
122
146
# we exclude outliers at this point, and then handle them separately later
123
- event_id_to_state = await storage .state .get_state_for_events (
124
- frozenset (e .event_id for e in events if not e .internal_metadata .outlier ),
125
- state_filter = StateFilter .from_types (types ),
126
- )
147
+ if events_to_fetch :
148
+ event_id_to_state = await storage .state .get_state_for_events (
149
+ events_to_fetch ,
150
+ state_filter = StateFilter .from_types (types ),
151
+ )
152
+ for event_id , state in event_id_to_state .items ():
153
+ visibilities [event_id ] = get_effective_room_visibility_from_state (state )
154
+ memberships [event_id ] = state .get ((EventTypes .Member , user_id ))
127
155
128
156
# Get the users who are ignored by the requesting user.
129
157
ignore_list = await storage .main .ignored_users (user_id )
@@ -140,18 +168,19 @@ async def filter_events_for_client(
140
168
] = await storage .main .get_retention_policy_for_room (room_id )
141
169
142
170
def allowed (event : EventBase ) -> Optional [EventBase ]:
143
- state_after_event = event_id_to_state .get (event .event_id )
144
- filtered = _check_client_allowed_to_see_event (
171
+ # state_after_event = event_id_to_state.get(event.event_id)
172
+ filtered = _check_client_allowed_to_see_event_with_state (
145
173
user_id = user_id ,
146
174
event = event ,
147
175
clock = storage .main .clock ,
148
176
filter_send_to_client = filter_send_to_client ,
149
177
sender_ignored = event .sender in ignore_list ,
150
178
always_include_ids = always_include_ids ,
151
179
retention_policy = retention_policies [event .room_id ],
152
- state = state_after_event ,
153
180
is_peeking = is_peeking ,
154
181
sender_erased = erased_senders .get (event .sender , False ),
182
+ visibility = visibilities [event .event_id ],
183
+ membership_event = memberships .get (event .event_id ),
155
184
)
156
185
if filtered is None :
157
186
return None
@@ -165,11 +194,9 @@ def allowed(event: EventBase) -> Optional[EventBase]:
165
194
user_membership_event : Optional [EventBase ]
166
195
if event .type == EventTypes .Member and event .state_key == user_id :
167
196
user_membership_event = event
168
- elif state_after_event is not None :
169
- user_membership_event = state_after_event .get ((EventTypes .Member , user_id ))
170
197
else :
171
- # unreachable!
172
- raise Exception ( "Missing state for event that is not user's own membership" )
198
+ # TODO: Actually get the proper membership
199
+ user_membership_event = memberships . get ( event_id )
173
200
174
201
user_membership = (
175
202
user_membership_event .membership
@@ -353,6 +380,41 @@ def _check_client_allowed_to_see_event(
353
380
354
381
the original event if they can see it as normal.
355
382
"""
383
+
384
+ visibility = HistoryVisibility .SHARED
385
+
386
+ if state is not None :
387
+ visibility = get_effective_room_visibility_from_state (state )
388
+ membership_event = state .get ((EventTypes .Member , user_id )) if state else None
389
+
390
+ return _check_client_allowed_to_see_event_with_state (
391
+ user_id ,
392
+ event ,
393
+ clock ,
394
+ filter_send_to_client ,
395
+ is_peeking ,
396
+ always_include_ids ,
397
+ sender_ignored ,
398
+ retention_policy ,
399
+ sender_erased ,
400
+ visibility = visibility ,
401
+ membership_event = membership_event ,
402
+ )
403
+
404
+
405
+ def _check_client_allowed_to_see_event_with_state (
406
+ user_id : str ,
407
+ event : EventBase ,
408
+ clock : Clock ,
409
+ filter_send_to_client : bool ,
410
+ is_peeking : bool ,
411
+ always_include_ids : FrozenSet [str ],
412
+ sender_ignored : bool ,
413
+ retention_policy : RetentionPolicy ,
414
+ sender_erased : bool ,
415
+ visibility : str ,
416
+ membership_event : Optional [EventBase ],
417
+ ) -> Optional [EventBase ]:
356
418
# Only run some checks if these events aren't about to be sent to clients. This is
357
419
# because, if this is not the case, we're probably only checking if the users can
358
420
# see events in the room at that point in the DAG, and that shouldn't be decided
@@ -390,12 +452,6 @@ def _check_client_allowed_to_see_event(
390
452
)
391
453
return None
392
454
393
- if state is None :
394
- raise Exception ("Missing state for non-outlier event" )
395
-
396
- # get the room_visibility at the time of the event.
397
- visibility = get_effective_room_visibility_from_state (state )
398
-
399
455
# Check if the room has lax history visibility, allowing us to skip
400
456
# membership checks.
401
457
#
@@ -408,6 +464,10 @@ def _check_client_allowed_to_see_event(
408
464
):
409
465
return event
410
466
467
+ if membership_event :
468
+ state = {(EventTypes .Member , user_id ): membership_event }
469
+ else :
470
+ state = {}
411
471
membership_result = _check_membership (user_id , event , visibility , state , is_peeking )
412
472
if not membership_result .allowed :
413
473
filtered_event_logger .debug (
0 commit comments