Skip to content

Commit

Permalink
later
Browse files Browse the repository at this point in the history
Reviewed By: itamaro

Differential Revision: D64526256

fbshipit-source-id: 39354717bb2e2ad6ff403a487d4e84506cbc5341
  • Loading branch information
generatedunixname89002005287564 authored and facebook-github-bot committed Oct 20, 2024
1 parent 1c836d1 commit c150a41
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 42 deletions.
59 changes: 27 additions & 32 deletions later/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,22 @@
import functools
import logging
import threading
from collections.abc import Awaitable, Callable, Coroutine, Hashable, Mapping, Sequence

from functools import partial, wraps
from inspect import isawaitable
from types import TracebackType
from typing import (
AbstractSet,
Any,
Awaitable,
Callable,
cast,
Coroutine,
Dict,
Hashable,
List,
Mapping,
NewType,
Optional,
overload,
ParamSpec,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -91,7 +86,7 @@ async def cancel(fut: asyncio.Future) -> None:
if fut.done():
return # nothing to do
fut.cancel()
exc: Optional[asyncio.CancelledError] = None
exc: asyncio.CancelledError | None = None
while not fut.done():
shielded = asyncio.shield(fut)
try:
Expand Down Expand Up @@ -156,13 +151,13 @@ class WatcherError(RuntimeError):


class Watcher:
_tasks: Dict[asyncio.Future, Optional[FixerType]]
_scheduled: List[FixerType]
_tasks: dict[asyncio.Future, FixerType | None]
_scheduled: list[FixerType]
_tasks_changed: BiDirectionalEvent
_cancelled: asyncio.Event
_cancel_timeout: float
_preexit_callbacks: List[Callable[[], None]]
_shielded_tasks: Dict[asyncio.Task, asyncio.Future]
_preexit_callbacks: list[Callable[[], None]]
_shielded_tasks: dict[asyncio.Task, asyncio.Future]
# pyre-ignore[13]: loop is initialized in __aenter__
loop: asyncio.AbstractEventLoop
running: bool
Expand All @@ -188,12 +183,12 @@ def __init__(
if context:
WATCHER_CONTEXT.set(self)
self._cancel_timeout = cancel_timeout
self._tasks: Dict[asyncio.Future, Optional[FixerType]] = {}
self._scheduled: List[FixerType] = []
self._tasks: dict[asyncio.Future, FixerType | None] = {}
self._scheduled: list[FixerType] = []
self._tasks_changed = BiDirectionalEvent()
self._cancelled = asyncio.Event()
self._preexit_callbacks = []
self._shielded_tasks: Dict[asyncio.Task, asyncio.Future] = {}
self._shielded_tasks: dict[asyncio.Task, asyncio.Future] = {}
self.running = False
self.done_ok = done_ok

Expand All @@ -213,7 +208,7 @@ async def _run_scheduled(self) -> None:
async def unwatch(
self,
task: asyncio.Task = START_TASK,
fixer: Optional[FixerType] = None,
fixer: FixerType | None = None,
*,
shield: bool = False,
) -> bool:
Expand Down Expand Up @@ -258,7 +253,7 @@ async def tasks_changed() -> None:
def watch(
self,
task: asyncio.Task = START_TASK,
fixer: Optional[FixerType] = None,
fixer: FixerType | None = None,
*,
shield: bool = False,
) -> None:
Expand Down Expand Up @@ -325,16 +320,16 @@ def _run_preexit_callbacks(self) -> None:
f"ignoring exception from pre-exit callback {callback}: {e}"
)

async def __aenter__(self) -> "Watcher":
async def __aenter__(self) -> Watcher:
WATCHER_CONTEXT.set(self)
self.loop = asyncio.get_running_loop()
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> bool:
cancel_task: asyncio.Task = self.loop.create_task(self._cancelled.wait())
changed_task: asyncio.Task = START_TASK
Expand Down Expand Up @@ -411,7 +406,7 @@ async def _handle_cancel(self) -> None:
task.cancel()

done, pending = await asyncio.wait(tasks, timeout=self._cancel_timeout)
bad_tasks: List[asyncio.Future] = []
bad_tasks: list[asyncio.Future] = []
for task in done:
if task.cancelled():
continue
Expand All @@ -433,26 +428,26 @@ async def _handle_cancel(self) -> None:
class _CountTask:
"""So herd can track herd size and task together for cancellation"""

task: Optional[asyncio.Task] = None
task: asyncio.Task | None = None
count: int = 0


def _get_local(local: threading.local, field: str) -> Dict[CacheKey, object]:
def _get_local(local: threading.local, field: str) -> dict[CacheKey, object]:
"""
helper for attempting to fetch a named attr from a threading.local
"""
try:
return cast(Dict[CacheKey, object], getattr(local, field))
return cast(dict[CacheKey, object], getattr(local, field))
except AttributeError:
container: Dict[CacheKey, object] = {}
container: dict[CacheKey, object] = {}
setattr(local, field, container)
return container


def _build_key(
args: Tuple[object, ...],
args: tuple[object, ...],
kwargs: Mapping[str, object],
ignored_args: Optional[AbstractSet[ArgID]] = None,
ignored_args: AbstractSet[ArgID] | None = None,
) -> CacheKey:
"""
Build a key for caching Hashable args and kwargs.
Expand All @@ -466,7 +461,7 @@ def _build_key(
(
tuple((value for idx, value in enumerate(args) if idx not in ignored_args)),
tuple(
(item for item in sorted(kwargs.items()) if item[0] not in ignored_args)
item for item in sorted(kwargs.items()) if item[0] not in ignored_args
),
)
)
Expand All @@ -483,7 +478,7 @@ def __call__(
def herd(
fn: Callable[TParams, Coroutine[object, object, T]],
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
ignored_args: AbstractSet[ArgID] | None = None,
) -> Callable[TParams, Coroutine[object, object, T]]: # pragma: nocover
...

Expand All @@ -492,15 +487,15 @@ def herd(
def herd(
fn: None = None,
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
ignored_args: AbstractSet[ArgID] | None = None,
) -> AsyncCallable: # pragma: nocover
...


def herd(
fn: Callable[TParams, Coroutine[object, object, T]] | None = None,
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
ignored_args: AbstractSet[ArgID] | None = None,
) -> (
Callable[TParams, Coroutine[object, object, T]]
| Callable[
Expand Down Expand Up @@ -528,7 +523,7 @@ def decorator(

@functools.wraps(fn)
async def wrapped(*args: TParams.args, **kwargs: TParams.kwargs) -> T:
pending = cast(Dict[CacheKey, _CountTask], _get_local(local, "pending"))
pending = cast(dict[CacheKey, _CountTask], _get_local(local, "pending"))
request = _build_key(tuple(args), kwargs, ignored_args)
count_task = pending.setdefault(request, _CountTask())
count_task.count += 1
Expand Down
2 changes: 1 addition & 1 deletion later/tests/unittest/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


# This is a place to purposefully produce leaked tasks
saved_tasks: List[asyncio.Task] = []
saved_tasks: list[asyncio.Task] = []


class TestTestCase(TestCase):
Expand Down
11 changes: 2 additions & 9 deletions later/unittest/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,10 @@
import sys
import unittest.mock as mock
import weakref
from collections.abc import Callable, Coroutine, Generator
from contextvars import Context
from functools import wraps
from typing import (
AbstractSet,
Callable,
Coroutine,
Generator,
Generic,
TYPE_CHECKING,
TypeVar,
)
from typing import AbstractSet, Generic, TYPE_CHECKING, TypeVar
from unittest import IsolatedAsyncioTestCase as AsyncioTestCase


Expand Down

0 comments on commit c150a41

Please sign in to comment.