diff --git a/changelog.d/17884.misc b/changelog.d/17884.misc new file mode 100644 index 00000000000..9dfa13f853c --- /dev/null +++ b/changelog.d/17884.misc @@ -0,0 +1 @@ +Minor speed-up of sliding sync by computing extensions results in parallel. diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 0c77b525139..077887ec321 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -49,7 +49,10 @@ SlidingSyncConfig, SlidingSyncResult, ) -from synapse.util.async_helpers import concurrently_execute +from synapse.util.async_helpers import ( + concurrently_execute, + gather_optional_coroutines, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -97,26 +100,26 @@ async def get_extensions_response( if sync_config.extensions is None: return SlidingSyncResult.Extensions() - to_device_response = None + to_device_coro = None if sync_config.extensions.to_device is not None: - to_device_response = await self.get_to_device_extension_response( + to_device_coro = self.get_to_device_extension_response( sync_config=sync_config, to_device_request=sync_config.extensions.to_device, to_token=to_token, ) - e2ee_response = None + e2ee_coro = None if sync_config.extensions.e2ee is not None: - e2ee_response = await self.get_e2ee_extension_response( + e2ee_coro = self.get_e2ee_extension_response( sync_config=sync_config, e2ee_request=sync_config.extensions.e2ee, to_token=to_token, from_token=from_token, ) - account_data_response = None + account_data_coro = None if sync_config.extensions.account_data is not None: - account_data_response = await self.get_account_data_extension_response( + account_data_coro = self.get_account_data_extension_response( sync_config=sync_config, previous_connection_state=previous_connection_state, new_connection_state=new_connection_state, @@ -127,9 +130,9 @@ async def get_extensions_response( from_token=from_token, ) - receipts_response = None + receipts_coro = None if sync_config.extensions.receipts is not None: - receipts_response = await self.get_receipts_extension_response( + receipts_coro = self.get_receipts_extension_response( sync_config=sync_config, previous_connection_state=previous_connection_state, new_connection_state=new_connection_state, @@ -141,9 +144,9 @@ async def get_extensions_response( from_token=from_token, ) - typing_response = None + typing_coro = None if sync_config.extensions.typing is not None: - typing_response = await self.get_typing_extension_response( + typing_coro = self.get_typing_extension_response( sync_config=sync_config, actual_lists=actual_lists, actual_room_ids=actual_room_ids, @@ -153,6 +156,20 @@ async def get_extensions_response( from_token=from_token, ) + ( + to_device_response, + e2ee_response, + account_data_response, + receipts_response, + typing_response, + ) = await gather_optional_coroutines( + to_device_coro, + e2ee_coro, + account_data_coro, + receipts_coro, + typing_coro, + ) + return SlidingSyncResult.Extensions( to_device=to_device_response, e2ee=e2ee_response, diff --git a/synapse/logging/context.py b/synapse/logging/context.py index ae2b3d11c07..8a2dfeba13c 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -37,6 +37,7 @@ from types import TracebackType from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Optional, @@ -850,6 +851,45 @@ def run_in_background( return d +def run_coroutine_in_background( + coroutine: typing.Coroutine[Any, Any, R], +) -> "defer.Deferred[R]": + """Run the coroutine, ensuring that the current context is restored after + return from the function, and that the sentinel context is set once the + deferred returned by the function completes. + + Useful for wrapping coroutines that you don't yield or await on (for + instance because you want to pass it to deferred.gatherResults()). + + This is a special case of `run_in_background` where we can accept a + coroutine directly rather than a function. We can do this because coroutines + do not run until called, and so calling an async function without awaiting + cannot change the log contexts. + """ + + current = current_context() + d = defer.ensureDeferred(coroutine) + + # The function may have reset the context before returning, so + # we need to restore it now. + ctx = set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + d.addBoth(_set_context_cb, ctx) + return d + + T = TypeVar("T") diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 8618bb0651c..e1eb8a48632 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -51,7 +51,7 @@ ) import attr -from typing_extensions import Concatenate, Literal, ParamSpec +from typing_extensions import Concatenate, Literal, ParamSpec, Unpack from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -61,6 +61,7 @@ from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, + run_coroutine_in_background, run_in_background, ) from synapse.util import Clock @@ -344,6 +345,7 @@ async def yieldable_gather_results_delaying_cancellation( T2 = TypeVar("T2") T3 = TypeVar("T3") T4 = TypeVar("T4") +T5 = TypeVar("T5") @overload @@ -402,6 +404,112 @@ def gather_results( # type: ignore[misc] return deferred.addCallback(tuple) +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]], +) -> Tuple[Optional[T1]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ... + + +async def gather_optional_coroutines( + *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], +) -> Tuple[Optional[T1], ...]: + """Helper function that allows waiting on multiple coroutines at once. + + The return value is a tuple of the return values of the coroutines in order. + + If a `None` is passed instead of a coroutine, it will be ignored and a None + is returned in the tuple. + + Note: For typechecking we need to have an explicit overload for each + distinct number of coroutines passed in. If you see type problems, it's + likely because you're using many arguments and you need to add a new + overload above. + """ + + try: + results = await make_deferred_yieldable( + defer.gatherResults( + [ + run_coroutine_in_background(coroutine) + for coroutine in coroutines + if coroutine is not None + ], + consumeErrors=True, + ) + ) + + results_iter = iter(results) + return tuple( + next(results_iter) if coroutine is not None else None + for coroutine in coroutines + ) + except defer.FirstError as dfe: + # unwrap the error from defer.gatherResults. + + # The raised exception's traceback only includes func() etc if + # the 'await' happens before the exception is thrown - ie if the failure + # happens *asynchronously* - otherwise Twisted throws away the traceback as it + # could be large. + # + # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe + # we could throw Twisted into the fires of Mordor. + + # suppress exception chaining, because the FirstError doesn't tell us anything + # very interesting. + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None + + @attr.s(slots=True, auto_attribs=True) class _LinearizerEntry: # The number of things executing. diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index d82822d00dc..350a2b7c8cd 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -18,7 +18,7 @@ # # import traceback -from typing import Generator, List, NoReturn, Optional +from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar from parameterized import parameterized_class @@ -39,6 +39,7 @@ ObservableDeferred, concurrently_execute, delay_cancellation, + gather_optional_coroutines, stop_cancellation, timeout_deferred, ) @@ -46,6 +47,8 @@ from tests.server import get_clock from tests.unittest import TestCase +T = TypeVar("T") + class ObservableDeferredTest(TestCase): def test_succeed(self) -> None: @@ -588,3 +591,106 @@ def test_multiple_sleepers_wake(self) -> None: sleeper.wake("name") self.assertTrue(d1.called) self.assertTrue(d2.called) + + +class GatherCoroutineTests(TestCase): + """Tests for `gather_optional_coroutines`""" + + def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]: + """Returns a coroutine and a deferred that it is waiting on to resolve""" + + d: "defer.Deferred[T]" = defer.Deferred() + + async def inner() -> T: + with PreserveLoggingContext(): + return await d + + return inner(), d + + def test_single(self) -> None: + "Test passing in a single coroutine works" + + with LoggingContext("test_ctx") as text_ctx: + deferred: "defer.Deferred[None]" + coroutine, deferred = self.make_coroutine() + + gather_deferred = defer.ensureDeferred( + gather_optional_coroutines(coroutine) + ) + + # We shouldn't have a result yet, and should be in the sentinel + # context. + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + # Resolving the deferred will resolve the coroutine + deferred.callback(None) + + # All coroutines have resolved, and so we should have the results + result = self.successResultOf(gather_deferred) + self.assertEqual(result, (None,)) + + # We should be back in the normal context. + self.assertEqual(current_context(), text_ctx) + + def test_multiple_resolve(self) -> None: + "Test passing in multiple coroutine that all resolve works" + + with LoggingContext("test_ctx") as test_ctx: + deferred1: "defer.Deferred[int]" + coroutine1, deferred1 = self.make_coroutine() + deferred2: "defer.Deferred[str]" + coroutine2, deferred2 = self.make_coroutine() + + gather_deferred = defer.ensureDeferred( + gather_optional_coroutines(coroutine1, coroutine2) + ) + + # We shouldn't have a result yet, and should be in the sentinel + # context. + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + # Even if we resolve one of the coroutines, we shouldn't have a result + # yet + deferred2.callback("test") + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + deferred1.callback(1) + + # All coroutines have resolved, and so we should have the results + result = self.successResultOf(gather_deferred) + self.assertEqual(result, (1, "test")) + + # We should be back in the normal context. + self.assertEqual(current_context(), test_ctx) + + def test_multiple_fail(self) -> None: + "Test passing in multiple coroutine where one fails does the right thing" + + with LoggingContext("test_ctx") as test_ctx: + deferred1: "defer.Deferred[int]" + coroutine1, deferred1 = self.make_coroutine() + deferred2: "defer.Deferred[str]" + coroutine2, deferred2 = self.make_coroutine() + + gather_deferred = defer.ensureDeferred( + gather_optional_coroutines(coroutine1, coroutine2) + ) + + # We shouldn't have a result yet, and should be in the sentinel + # context. + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + # Throw an exception in one of the coroutines + exc = Exception("test") + deferred2.errback(exc) + + # Expect the gather deferred to immediately fail + result_exc = self.failureResultOf(gather_deferred) + self.assertEqual(result_exc.value, exc) + + # We should be back in the normal context. + self.assertEqual(current_context(), test_ctx)