44import logging
55import asyncio
66import warnings
7+ from functools import partial
78
89from types import TracebackType
910from typing import TYPE_CHECKING , Optional , Collection , Generic , Type , cast
1213
1314import h2 .config
1415import h2 .exceptions
16+ from h2 .errors import ErrorCodes
1517
1618from multidict import MultiDict
1719
2426from .metadata import Deadline , encode_grpc_message , _Metadata
2527from .metadata import encode_metadata , decode_metadata , _MetadataLike
2628from .metadata import _STATUS_DETAILS_KEY , encode_bin_value
27- from .protocol import H2Protocol , AbstractHandler
29+ from .protocol import H2Protocol , AbstractHandler , Connection
2830from .exceptions import GRPCError , ProtocolError , StreamTerminatedError
2931from .encoding .base import GRPC_CONTENT_TYPE , CodecBase , StatusDetailsCodecBase
3032from .encoding .proto import ProtoCodec , ProtoStatusDetailsCodec
@@ -493,9 +495,8 @@ def __gc_step__(self) -> None:
493495 self .__gc_collect__ ()
494496
495497
496- class Handler (_GC , AbstractHandler ):
497- __gc_interval__ = 10
498-
498+ class Handler (AbstractHandler ):
499+ connection : Connection
499500 closing = False
500501
501502 def __init__ (
@@ -511,44 +512,51 @@ def __init__(
511512 self .dispatch = dispatch
512513 self .loop = asyncio .get_event_loop ()
513514 self ._tasks : Dict ['protocol.Stream' , 'asyncio.Task[None]' ] = {}
514- self ._cancelled : Set ['asyncio.Task[None]' ] = set ()
515515
516- def __gc_collect__ (self ) -> None :
517- self ._tasks = {s : t for s , t in self ._tasks .items ()
518- if not t .done ()}
519- self ._cancelled = {t for t in self ._cancelled
520- if not t .done ()}
516+ def connection_made (self , connection : Connection ) -> None :
517+ self .connection = connection
518+
519+ def handler_done (
520+ self ,
521+ stream : 'protocol.Stream' ,
522+ _ : 'asyncio.Future[None]' ,
523+ ) -> None :
524+ self ._tasks .pop (stream )
525+ if self .closing and not self ._tasks :
526+ self .connection .close ()
521527
522528 def accept (
523529 self ,
524530 stream : 'protocol.Stream' ,
525531 headers : _Headers ,
526532 release_stream : Callable [[], Any ],
527533 ) -> None :
528- self .__gc_step__ ()
529- self ._tasks [stream ] = self .loop .create_task (request_handler (
530- self .mapping , stream , headers , self .codec ,
531- self .status_details_codec , self .dispatch , release_stream ,
532- ))
534+ if self .closing :
535+ stream .reset_nowait (ErrorCodes .REFUSED_STREAM )
536+ release_stream ()
537+ else :
538+ task = self ._tasks [stream ] = self .loop .create_task (request_handler (
539+ self .mapping , stream , headers , self .codec ,
540+ self .status_details_codec , self .dispatch , release_stream ,
541+ ))
542+ task .add_done_callback (partial (self .handler_done , stream ))
533543
534544 def cancel (self , stream : 'protocol.Stream' ) -> None :
535- task = self ._tasks .pop (stream )
536- task .cancel ()
537- self ._cancelled .add (task )
545+ self ._tasks [stream ].cancel ()
538546
539547 def close (self ) -> None :
540548 for task in self ._tasks .values ():
541549 task .cancel ()
542- self ._cancelled .update (self ._tasks .values ())
543550 self .closing = True
544551
545552 async def wait_closed (self ) -> None :
546- if self ._cancelled :
547- await asyncio .wait (self ._cancelled )
553+ if self ._tasks :
554+ await asyncio .wait (self ._tasks .values ())
555+ else :
556+ self .connection .close ()
548557
549558 def check_closed (self ) -> bool :
550- self .__gc_collect__ ()
551- return not self ._tasks and not self ._cancelled
559+ return not self ._tasks
552560
553561
554562class Server (_GC ):
@@ -737,11 +745,11 @@ async def wait_closed(self) -> None:
737745 if self ._server is None or self ._server_closed_fut is None :
738746 raise RuntimeError ('Server is not started' )
739747 await self ._server_closed_fut
740- await self ._server .wait_closed ()
741748 if self ._handlers :
742749 await asyncio .wait ({
743750 self ._loop .create_task (h .wait_closed ()) for h in self ._handlers
744751 })
752+ await self ._server .wait_closed ()
745753
746754 async def __aenter__ (self ) -> 'Server' :
747755 return self
0 commit comments