Skip to content

Commit fb751d3

Browse files
committed
Use fast path
1 parent 35d797a commit fb751d3

File tree

2 files changed

+112
-19
lines changed

2 files changed

+112
-19
lines changed

synapse/storage/databases/main/sliding_sync.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515

1616
import logging
17-
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
17+
from typing import TYPE_CHECKING, Collection, Dict, List, Mapping, Optional, Set, cast
1818

1919
import attr
2020

2121
from synapse.api.errors import SlidingSyncUnknownPosition
22+
from synapse.events import EventBase
2223
from synapse.logging.opentracing import log_kv
2324
from synapse.storage._base import SQLBaseStore, db_to_json
2425
from synapse.storage.database import LoggingTransaction
@@ -451,6 +452,38 @@ def _get_and_clear_connection_positions_txn(
451452
room_configs=room_configs,
452453
)
453454

455+
async def get_visibility_for_events(
456+
self, room_id: str, events: Collection[EventBase]
457+
) -> Mapping[str, Optional[str]]:
458+
def get_visibility_for_events_txn(
459+
txn: LoggingTransaction,
460+
) -> Mapping[str, Optional[str]]:
461+
sql = """
462+
SELECT visibility FROM history_visibility_ranges
463+
WHERE start_range <= ? AND (? < end_range OR end_range IS NULL)
464+
AND room_id = ?
465+
"""
466+
467+
results = {}
468+
for event in events:
469+
txn.execute(
470+
sql,
471+
(
472+
event.internal_metadata.stream_ordering,
473+
event.internal_metadata.stream_ordering,
474+
room_id,
475+
),
476+
)
477+
row = txn.fetchone()
478+
if row is not None:
479+
results[event.event_id] = row[0]
480+
481+
return results
482+
483+
return await self.db_pool.runInteraction(
484+
"get_visibility_for_events", get_visibility_for_events_txn
485+
)
486+
454487

455488
@attr.s(auto_attribs=True, frozen=True)
456489
class PerConnectionStateDB:

synapse/visibility.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ async def filter_events_for_client(
105105
The filtered events. The `unsigned` data is annotated with the membership state
106106
of `user_id` at each event.
107107
"""
108+
if not events:
109+
return []
110+
108111
# Filter out events that have been soft failed so that we don't relay them
109112
# to clients.
110113
events_before_filtering = events
@@ -117,13 +120,38 @@ async def filter_events_for_client(
117120
[event.event_id for event in events],
118121
)
119122

120-
types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
123+
types = (
124+
_HISTORY_VIS_KEY,
125+
(EventTypes.Member, user_id),
126+
)
127+
128+
room_id = events[0].room_id
129+
assert all(event.room_id == room_id for event in events)
130+
131+
visibilities: Dict[str, str] = {}
132+
memberships: Dict[str, Optional[EventBase]] = {}
133+
events_to_fetch = {e.event_id for e in events if not e.internal_metadata.outlier}
134+
if not is_peeking:
135+
fetched_visibilities = await storage.main.get_visibility_for_events(
136+
room_id, [e for e in events if not e.internal_metadata.outlier]
137+
)
138+
for event_id, visibility in fetched_visibilities.items():
139+
if visibility in (
140+
HistoryVisibility.SHARED,
141+
HistoryVisibility.WORLD_READABLE,
142+
):
143+
events_to_fetch.discard(event_id)
144+
visibilities[event_id] = visibility
121145

122146
# we exclude outliers at this point, and then handle them separately later
123-
event_id_to_state = await storage.state.get_state_for_events(
124-
frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
125-
state_filter=StateFilter.from_types(types),
126-
)
147+
if events_to_fetch:
148+
event_id_to_state = await storage.state.get_state_for_events(
149+
events_to_fetch,
150+
state_filter=StateFilter.from_types(types),
151+
)
152+
for event_id, state in event_id_to_state.items():
153+
visibilities[event_id] = get_effective_room_visibility_from_state(state)
154+
memberships[event_id] = state.get((EventTypes.Member, user_id))
127155

128156
# Get the users who are ignored by the requesting user.
129157
ignore_list = await storage.main.ignored_users(user_id)
@@ -140,18 +168,19 @@ async def filter_events_for_client(
140168
] = await storage.main.get_retention_policy_for_room(room_id)
141169

142170
def allowed(event: EventBase) -> Optional[EventBase]:
143-
state_after_event = event_id_to_state.get(event.event_id)
144-
filtered = _check_client_allowed_to_see_event(
171+
# state_after_event = event_id_to_state.get(event.event_id)
172+
filtered = _check_client_allowed_to_see_event_with_state(
145173
user_id=user_id,
146174
event=event,
147175
clock=storage.main.clock,
148176
filter_send_to_client=filter_send_to_client,
149177
sender_ignored=event.sender in ignore_list,
150178
always_include_ids=always_include_ids,
151179
retention_policy=retention_policies[event.room_id],
152-
state=state_after_event,
153180
is_peeking=is_peeking,
154181
sender_erased=erased_senders.get(event.sender, False),
182+
visibility=visibilities[event.event_id],
183+
membership_event=memberships.get(event.event_id),
155184
)
156185
if filtered is None:
157186
return None
@@ -165,11 +194,9 @@ def allowed(event: EventBase) -> Optional[EventBase]:
165194
user_membership_event: Optional[EventBase]
166195
if event.type == EventTypes.Member and event.state_key == user_id:
167196
user_membership_event = event
168-
elif state_after_event is not None:
169-
user_membership_event = state_after_event.get((EventTypes.Member, user_id))
170197
else:
171-
# unreachable!
172-
raise Exception("Missing state for event that is not user's own membership")
198+
# TODO: Actually get the proper membership
199+
user_membership_event = memberships.get(event_id)
173200

174201
user_membership = (
175202
user_membership_event.membership
@@ -353,6 +380,41 @@ def _check_client_allowed_to_see_event(
353380
354381
the original event if they can see it as normal.
355382
"""
383+
384+
visibility = HistoryVisibility.SHARED
385+
386+
if state is not None:
387+
visibility = get_effective_room_visibility_from_state(state)
388+
membership_event = state.get((EventTypes.Member, user_id)) if state else None
389+
390+
return _check_client_allowed_to_see_event_with_state(
391+
user_id,
392+
event,
393+
clock,
394+
filter_send_to_client,
395+
is_peeking,
396+
always_include_ids,
397+
sender_ignored,
398+
retention_policy,
399+
sender_erased,
400+
visibility=visibility,
401+
membership_event=membership_event,
402+
)
403+
404+
405+
def _check_client_allowed_to_see_event_with_state(
406+
user_id: str,
407+
event: EventBase,
408+
clock: Clock,
409+
filter_send_to_client: bool,
410+
is_peeking: bool,
411+
always_include_ids: FrozenSet[str],
412+
sender_ignored: bool,
413+
retention_policy: RetentionPolicy,
414+
sender_erased: bool,
415+
visibility: str,
416+
membership_event: Optional[EventBase],
417+
) -> Optional[EventBase]:
356418
# Only run some checks if these events aren't about to be sent to clients. This is
357419
# because, if this is not the case, we're probably only checking if the users can
358420
# see events in the room at that point in the DAG, and that shouldn't be decided
@@ -390,12 +452,6 @@ def _check_client_allowed_to_see_event(
390452
)
391453
return None
392454

393-
if state is None:
394-
raise Exception("Missing state for non-outlier event")
395-
396-
# get the room_visibility at the time of the event.
397-
visibility = get_effective_room_visibility_from_state(state)
398-
399455
# Check if the room has lax history visibility, allowing us to skip
400456
# membership checks.
401457
#
@@ -408,6 +464,10 @@ def _check_client_allowed_to_see_event(
408464
):
409465
return event
410466

467+
if membership_event:
468+
state = {(EventTypes.Member, user_id): membership_event}
469+
else:
470+
state = {}
411471
membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
412472
if not membership_result.allowed:
413473
filtered_event_logger.debug(

0 commit comments

Comments
 (0)