Skip to content

Commit 777d6c3

Browse files
committed
Refactored server close logic to gracefully exit without using GOAWAY frames
1 parent 5916cba commit 777d6c3

File tree

5 files changed

+46
-30
lines changed

5 files changed

+46
-30
lines changed

grpclib/client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@
6262

6363

6464
class Handler(AbstractHandler):
65-
connection_lost = False
65+
closing = False
66+
67+
def connection_made(self, connection: Any) -> None:
68+
pass
6669

6770
def accept(self, stream: Any, headers: Any, release_stream: Any) -> None:
6871
raise NotImplementedError('Client connection can not accept requests')
@@ -71,7 +74,7 @@ def cancel(self, stream: Any) -> None:
7174
pass
7275

7376
def close(self) -> None:
74-
self.connection_lost = True
77+
self.closing = True
7578

7679

7780
class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]):
@@ -737,7 +740,7 @@ async def _create_connection(self) -> H2Protocol:
737740
@property
738741
def _connected(self) -> bool:
739742
return (self._protocol is not None
740-
and not self._protocol.handler.connection_lost)
743+
and not cast(Handler, self._protocol.handler).closing)
741744

742745
async def __connect__(self) -> H2Protocol:
743746
if not self._connected:

grpclib/protocol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,10 @@ def closable(self) -> bool:
488488

489489
class AbstractHandler(ABC):
490490

491+
@abstractmethod
492+
def connection_made(self, connection: Connection) -> None:
493+
pass
494+
491495
@abstractmethod
492496
def accept(
493497
self,
@@ -709,6 +713,7 @@ def connection_made(self, transport: BaseTransport) -> None:
709713
self.connection.flush()
710714
self.connection.initialize()
711715

716+
self.handler.connection_made(self.connection)
712717
self.processor = EventsProcessor(self.handler, self.connection)
713718

714719
def data_received(self, data: bytes) -> None:

grpclib/server.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import asyncio
66
import warnings
7+
from functools import partial
78

89
from types import TracebackType
910
from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast
@@ -12,6 +13,7 @@
1213

1314
import h2.config
1415
import h2.exceptions
16+
from h2.errors import ErrorCodes
1517

1618
from multidict import MultiDict
1719

@@ -24,7 +26,7 @@
2426
from .metadata import Deadline, encode_grpc_message, _Metadata
2527
from .metadata import encode_metadata, decode_metadata, _MetadataLike
2628
from .metadata import _STATUS_DETAILS_KEY, encode_bin_value
27-
from .protocol import H2Protocol, AbstractHandler
29+
from .protocol import H2Protocol, AbstractHandler, Connection
2830
from .exceptions import GRPCError, ProtocolError, StreamTerminatedError
2931
from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase
3032
from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec
@@ -493,9 +495,8 @@ def __gc_step__(self) -> None:
493495
self.__gc_collect__()
494496

495497

496-
class Handler(_GC, AbstractHandler):
497-
__gc_interval__ = 10
498-
498+
class Handler(AbstractHandler):
499+
connection: Connection
499500
closing = False
500501

501502
def __init__(
@@ -511,44 +512,51 @@ def __init__(
511512
self.dispatch = dispatch
512513
self.loop = asyncio.get_event_loop()
513514
self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {}
514-
self._cancelled: Set['asyncio.Task[None]'] = set()
515515

516-
def __gc_collect__(self) -> None:
517-
self._tasks = {s: t for s, t in self._tasks.items()
518-
if not t.done()}
519-
self._cancelled = {t for t in self._cancelled
520-
if not t.done()}
516+
def connection_made(self, connection: Connection) -> None:
517+
self.connection = connection
518+
519+
def handler_done(
520+
self,
521+
stream: 'protocol.Stream',
522+
_: 'asyncio.Future[None]',
523+
) -> None:
524+
self._tasks.pop(stream)
525+
if self.closing and not self._tasks:
526+
self.connection.close()
521527

522528
def accept(
523529
self,
524530
stream: 'protocol.Stream',
525531
headers: _Headers,
526532
release_stream: Callable[[], Any],
527533
) -> None:
528-
self.__gc_step__()
529-
self._tasks[stream] = self.loop.create_task(request_handler(
530-
self.mapping, stream, headers, self.codec,
531-
self.status_details_codec, self.dispatch, release_stream,
532-
))
534+
if self.closing:
535+
stream.reset_nowait(ErrorCodes.REFUSED_STREAM)
536+
release_stream()
537+
else:
538+
task = self._tasks[stream] = self.loop.create_task(request_handler(
539+
self.mapping, stream, headers, self.codec,
540+
self.status_details_codec, self.dispatch, release_stream,
541+
))
542+
task.add_done_callback(partial(self.handler_done, stream))
533543

534544
def cancel(self, stream: 'protocol.Stream') -> None:
535-
task = self._tasks.pop(stream)
536-
task.cancel()
537-
self._cancelled.add(task)
545+
self._tasks[stream].cancel()
538546

539547
def close(self) -> None:
540548
for task in self._tasks.values():
541549
task.cancel()
542-
self._cancelled.update(self._tasks.values())
543550
self.closing = True
544551

545552
async def wait_closed(self) -> None:
546-
if self._cancelled:
547-
await asyncio.wait(self._cancelled)
553+
if self._tasks:
554+
await asyncio.wait(self._tasks.values())
555+
else:
556+
self.connection.close()
548557

549558
def check_closed(self) -> bool:
550-
self.__gc_collect__()
551-
return not self._tasks and not self._cancelled
559+
return not self._tasks
552560

553561

554562
class Server(_GC):
@@ -737,11 +745,11 @@ async def wait_closed(self) -> None:
737745
if self._server is None or self._server_closed_fut is None:
738746
raise RuntimeError('Server is not started')
739747
await self._server_closed_fut
740-
await self._server.wait_closed()
741748
if self._handlers:
742749
await asyncio.wait({
743750
self._loop.create_task(h.wait_closed()) for h in self._handlers
744751
})
752+
await self._server.wait_closed()
745753

746754
async def __aenter__(self) -> 'Server':
747755
return self

tests/stubs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler):
4747
headers = None
4848
release_stream = None
4949

50+
def connection_made(self, connection):
51+
pass
52+
5053
def accept(self, stream, headers, release_stream):
5154
self.stream = stream
5255
self.headers = headers

tests/test_memory.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,11 @@ async def test_stream():
8383
cs = ClientServer(DummyService, DummyServiceStub)
8484
async with cs as (_, stub):
8585
await stub.UnaryUnary(DummyRequest(value='ping'))
86-
handler = next(iter(cs.server._handlers))
87-
handler.__gc_collect__()
8886
gc.collect()
8987
gc.disable()
9088
try:
9189
pre = set(collect())
9290
await stub.UnaryUnary(DummyRequest(value='ping'))
93-
handler.__gc_collect__()
9491
post = collect()
9592

9693
diff = set(post).difference(pre)

0 commit comments

Comments
 (0)