44import logging
55import asyncio
66import warnings
7+ from functools import partial
78
89from types import TracebackType
910from typing import TYPE_CHECKING , Optional , Collection , Generic , Type , cast
1213
1314import h2 .config
1415import h2 .exceptions
16+ from h2 .errors import ErrorCodes
1517
1618from multidict import MultiDict
1719
2426from .metadata import Deadline , encode_grpc_message , _Metadata
2527from .metadata import encode_metadata , decode_metadata , _MetadataLike
2628from .metadata import _STATUS_DETAILS_KEY , encode_bin_value
27- from .protocol import H2Protocol , AbstractHandler
29+ from .protocol import H2Protocol , AbstractHandler , Connection
2830from .exceptions import GRPCError , ProtocolError , StreamTerminatedError
2931from .encoding .base import GRPC_CONTENT_TYPE , CodecBase , StatusDetailsCodecBase
3032from .encoding .proto import ProtoCodec , ProtoStatusDetailsCodec
@@ -496,6 +498,7 @@ def __gc_step__(self) -> None:
496498class Handler (_GC , AbstractHandler ):
497499 __gc_interval__ = 10
498500
501+ connection : Connection
499502 closing = False
500503
501504 def __init__ (
@@ -511,44 +514,54 @@ def __init__(
511514 self .dispatch = dispatch
512515 self .loop = asyncio .get_event_loop ()
513516 self ._tasks : Dict ['protocol.Stream' , 'asyncio.Task[None]' ] = {}
514- self ._cancelled : Set ['asyncio.Task[None]' ] = set ()
515517
516518 def __gc_collect__ (self ) -> None :
517- self ._tasks = {s : t for s , t in self ._tasks .items ()
518- if not t .done ()}
519- self ._cancelled = {t for t in self ._cancelled
520- if not t .done ()}
519+ self ._tasks = {s : t for s , t in self ._tasks .items () if not t .done ()}
520+
521+ def connection_made (self , connection : Connection ) -> None :
522+ self .connection = connection
523+
524+ def handler_done (self , stream : 'protocol.Stream' , _ : Any ) -> None :
525+ self ._tasks .pop (stream , None )
526+ if not self ._tasks :
527+ self .connection .close ()
521528
522529 def accept (
523530 self ,
524531 stream : 'protocol.Stream' ,
525532 headers : _Headers ,
526533 release_stream : Callable [[], Any ],
527534 ) -> None :
528- self .__gc_step__ ()
529- self ._tasks [stream ] = self .loop .create_task (request_handler (
530- self .mapping , stream , headers , self .codec ,
531- self .status_details_codec , self .dispatch , release_stream ,
532- ))
535+ if self .closing :
536+ stream .reset_nowait (ErrorCodes .REFUSED_STREAM )
537+ release_stream ()
538+ else :
539+ self .__gc_step__ ()
540+ self ._tasks [stream ] = self .loop .create_task (request_handler (
541+ self .mapping , stream , headers , self .codec ,
542+ self .status_details_codec , self .dispatch , release_stream ,
543+ ))
533544
534545 def cancel (self , stream : 'protocol.Stream' ) -> None :
535- task = self ._tasks .pop (stream )
536- task .cancel ()
537- self ._cancelled .add (task )
546+ self ._tasks [stream ].cancel ()
538547
539548 def close (self ) -> None :
540- for task in self ._tasks .values ():
549+ self .__gc_collect__ ()
550+ for stream , task in self ._tasks .items ():
551+ task .add_done_callback (partial (self .handler_done , stream ))
541552 task .cancel ()
542- self ._cancelled .update (self ._tasks .values ())
543553 self .closing = True
544554
545555 async def wait_closed (self ) -> None :
546- if self ._cancelled :
547- await asyncio .wait (self ._cancelled )
556+ self .__gc_collect__ ()
557+ if self ._tasks :
558+ await asyncio .wait (self ._tasks .values ())
559+ else :
560+ self .connection .close ()
548561
549562 def check_closed (self ) -> bool :
550563 self .__gc_collect__ ()
551- return not self ._tasks and not self . _cancelled
564+ return not self ._tasks
552565
553566
554567class Server (_GC ):
@@ -737,11 +750,11 @@ async def wait_closed(self) -> None:
737750 if self ._server is None or self ._server_closed_fut is None :
738751 raise RuntimeError ('Server is not started' )
739752 await self ._server_closed_fut
740- await self ._server .wait_closed ()
741753 if self ._handlers :
742754 await asyncio .wait ({
743755 self ._loop .create_task (h .wait_closed ()) for h in self ._handlers
744756 })
757+ await self ._server .wait_closed ()
745758
746759 async def __aenter__ (self ) -> 'Server' :
747760 return self
0 commit comments