Skip to content

Commit

Permalink
avoid refcycles in happy eyeballs
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Oct 13, 2024
1 parent 163f10c commit 41264d3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 15 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ test = [
"""\
uvloop >= 0.21.0b1; platform_python_implementation == 'CPython' \
and platform_system != 'Windows'\
"""
""",
"ephemeral-port-reserve >= 1.1.4",
]
doc = [
"packaging",
Expand Down
31 changes: 17 additions & 14 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,20 +211,23 @@ async def try_connect(remote_host: str, event: Event) -> None:
target_addrs = [(socket.AF_INET, addr_obj.compressed)]

oserrors: list[OSError] = []
async with create_task_group() as tg:
for i, (af, addr) in enumerate(target_addrs):
event = Event()
tg.start_soon(try_connect, addr, event)
with move_on_after(happy_eyeballs_delay):
await event.wait()

if connected_stream is None:
cause = (
oserrors[0]
if len(oserrors) == 1
else ExceptionGroup("multiple connection attempts failed", oserrors)
)
raise OSError("All connection attempts failed") from cause
try:
async with create_task_group() as tg:
for i, (af, addr) in enumerate(target_addrs):
event = Event()
tg.start_soon(try_connect, addr, event)
with move_on_after(happy_eyeballs_delay):
await event.wait()

if connected_stream is None:
cause = (
oserrors[0]
if len(oserrors) == 1
else ExceptionGroup("multiple connection attempts failed", oserrors)
)
raise OSError("All connection attempts failed") from cause
finally:
oserrors.clear()

if tls or tls_hostname or ssl_context:
try:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from threading import Thread
from typing import Any, NoReturn, TypeVar, cast

import ephemeral_port_reserve
import psutil
import pytest
from _pytest.fixtures import SubRequest
Expand Down Expand Up @@ -125,6 +126,16 @@ def check_asyncio_bug(anyio_backend_name: str, family: AnyIPAddressFamily) -> No
pytest.skip("Does not work due to a known bug (39148)")


if sys.version_info <= (3, 11):

def no_other_refs() -> list[object]:
return [sys._getframe(1)]
else:

def no_other_refs() -> list[object]:
return []


_T = TypeVar("_T")


Expand Down Expand Up @@ -307,6 +318,28 @@ def serve() -> None:
server_sock.close()
assert client_addr[0] == expected_client_addr

@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy see "
"https://github.com/pypy/pypy/issues/5075"
),
)
async def test_happy_eyeballs_refcycles(self) -> None:
"""
Test derived from https://github.com/python/cpython/pull/124859
"""
port = ephemeral_port_reserve.reserve()
exc = None
try:
async with await connect_tcp("localhost", port):
pass
except OSError as e:
exc = e.__cause__

assert isinstance(exc, OSError)
assert gc.get_referrers(exc) == no_other_refs()

@pytest.mark.parametrize(
"target, exception_class",
[
Expand Down

0 comments on commit 41264d3

Please sign in to comment.