Skip to content

Commit

Permalink
fix(asyncio): already cancelled tasks ends up in 'InvalidStateError: …
Browse files Browse the repository at this point in the history
…invalid state' (#2593)
  • Loading branch information
mxschmitt authored Oct 11, 2024
1 parent a71a0ce commit 4d31bdc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions playwright/_impl/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ def cleanup(self, cause: str = None) -> None:
# To prevent 'Future exception was never retrieved' we ignore all callbacks that are no_reply.
if callback.no_reply:
continue
if callback.future.cancelled():
continue
callback.future.set_exception(self._closed_error)
self._callbacks.clear()
self.emit("close")
Expand Down
22 changes: 21 additions & 1 deletion tests/async/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.
import asyncio
import gc
import sys
from typing import Dict

import pytest

from playwright.async_api import async_playwright
from playwright.async_api import Page, async_playwright
from tests.server import Server
from tests.utils import TARGET_CLOSED_ERROR_MESSAGE

Expand Down Expand Up @@ -67,3 +68,22 @@ async def test_cancel_pending_protocol_call_on_playwright_stop(server: Server) -
with pytest.raises(Exception) as exc_info:
await pending_task
assert TARGET_CLOSED_ERROR_MESSAGE in str(exc_info.value)


async def test_should_not_throw_with_taskgroup(page: Page) -> None:
if sys.version_info < (3, 11):
pytest.skip("TaskGroup is only available in Python 3.11+")

from builtins import ExceptionGroup # type: ignore

async def raise_exception() -> None:
raise ValueError("Something went wrong")

with pytest.raises(ExceptionGroup) as exc_info:
async with asyncio.TaskGroup() as group: # type: ignore
group.create_task(page.locator(".this-element-does-not-exist").inner_text())
group.create_task(raise_exception())
assert len(exc_info.value.exceptions) == 1
assert "Something went wrong" in str(exc_info.value.exceptions[0])
assert isinstance(exc_info.value.exceptions[0], ValueError)
assert await page.evaluate("() => 11 * 11") == 121

0 comments on commit 4d31bdc

Please sign in to comment.