Skip to content

Commit 6cd97ca

Browse files
committed
Refactored server close logic to gracefully exit without using GOAWAY frames
1 parent 5916cba commit 6cd97ca

File tree

4 files changed

+47
-23
lines changed

4 files changed

+47
-23
lines changed

grpclib/client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@
6262

6363

6464
class Handler(AbstractHandler):
65-
connection_lost = False
65+
closing = False
66+
67+
def connection_made(self, connection: Any) -> None:
68+
pass
6669

6770
def accept(self, stream: Any, headers: Any, release_stream: Any) -> None:
6871
raise NotImplementedError('Client connection can not accept requests')
@@ -71,7 +74,7 @@ def cancel(self, stream: Any) -> None:
7174
pass
7275

7376
def close(self) -> None:
74-
self.connection_lost = True
77+
self.closing = True
7578

7679

7780
class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]):
@@ -737,7 +740,7 @@ async def _create_connection(self) -> H2Protocol:
737740
@property
738741
def _connected(self) -> bool:
739742
return (self._protocol is not None
740-
and not self._protocol.handler.connection_lost)
743+
and not cast(Handler, self._protocol.handler).closing)
741744

742745
async def __connect__(self) -> H2Protocol:
743746
if not self._connected:

grpclib/protocol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,10 @@ def closable(self) -> bool:
488488

489489
class AbstractHandler(ABC):
490490

491+
@abstractmethod
492+
def connection_made(self, connection: Connection) -> None:
493+
pass
494+
491495
@abstractmethod
492496
def accept(
493497
self,
@@ -709,6 +713,7 @@ def connection_made(self, transport: BaseTransport) -> None:
709713
self.connection.flush()
710714
self.connection.initialize()
711715

716+
self.handler.connection_made(self.connection)
712717
self.processor = EventsProcessor(self.handler, self.connection)
713718

714719
def data_received(self, data: bytes) -> None:

grpclib/server.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import asyncio
66
import warnings
7+
from functools import partial
78

89
from types import TracebackType
910
from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast
@@ -12,6 +13,7 @@
1213

1314
import h2.config
1415
import h2.exceptions
16+
from h2.errors import ErrorCodes
1517

1618
from multidict import MultiDict
1719

@@ -24,7 +26,7 @@
2426
from .metadata import Deadline, encode_grpc_message, _Metadata
2527
from .metadata import encode_metadata, decode_metadata, _MetadataLike
2628
from .metadata import _STATUS_DETAILS_KEY, encode_bin_value
27-
from .protocol import H2Protocol, AbstractHandler
29+
from .protocol import H2Protocol, AbstractHandler, Connection
2830
from .exceptions import GRPCError, ProtocolError, StreamTerminatedError
2931
from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase
3032
from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec
@@ -496,6 +498,7 @@ def __gc_step__(self) -> None:
496498
class Handler(_GC, AbstractHandler):
497499
__gc_interval__ = 10
498500

501+
connection: Connection
499502
closing = False
500503

501504
def __init__(
@@ -511,44 +514,54 @@ def __init__(
511514
self.dispatch = dispatch
512515
self.loop = asyncio.get_event_loop()
513516
self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {}
514-
self._cancelled: Set['asyncio.Task[None]'] = set()
515517

516518
def __gc_collect__(self) -> None:
517-
self._tasks = {s: t for s, t in self._tasks.items()
518-
if not t.done()}
519-
self._cancelled = {t for t in self._cancelled
520-
if not t.done()}
519+
self._tasks = {s: t for s, t in self._tasks.items() if not t.done()}
520+
521+
def connection_made(self, connection: Connection) -> None:
522+
self.connection = connection
523+
524+
def handler_done(self, stream: 'protocol.Stream', _: Any) -> None:
525+
self._tasks.pop(stream, None)
526+
if not self._tasks:
527+
self.connection.close()
521528

522529
def accept(
523530
self,
524531
stream: 'protocol.Stream',
525532
headers: _Headers,
526533
release_stream: Callable[[], Any],
527534
) -> None:
528-
self.__gc_step__()
529-
self._tasks[stream] = self.loop.create_task(request_handler(
530-
self.mapping, stream, headers, self.codec,
531-
self.status_details_codec, self.dispatch, release_stream,
532-
))
535+
if self.closing:
536+
stream.reset_nowait(ErrorCodes.REFUSED_STREAM)
537+
release_stream()
538+
else:
539+
self.__gc_step__()
540+
self._tasks[stream] = self.loop.create_task(request_handler(
541+
self.mapping, stream, headers, self.codec,
542+
self.status_details_codec, self.dispatch, release_stream,
543+
))
533544

534545
def cancel(self, stream: 'protocol.Stream') -> None:
535-
task = self._tasks.pop(stream)
536-
task.cancel()
537-
self._cancelled.add(task)
546+
self._tasks[stream].cancel()
538547

539548
def close(self) -> None:
540-
for task in self._tasks.values():
549+
self.__gc_collect__()
550+
for stream, task in self._tasks.items():
551+
task.add_done_callback(partial(self.handler_done, stream))
541552
task.cancel()
542-
self._cancelled.update(self._tasks.values())
543553
self.closing = True
544554

545555
async def wait_closed(self) -> None:
546-
if self._cancelled:
547-
await asyncio.wait(self._cancelled)
556+
self.__gc_collect__()
557+
if self._tasks:
558+
await asyncio.wait(self._tasks.values())
559+
else:
560+
self.connection.close()
548561

549562
def check_closed(self) -> bool:
550563
self.__gc_collect__()
551-
return not self._tasks and not self._cancelled
564+
return not self._tasks
552565

553566

554567
class Server(_GC):
@@ -737,11 +750,11 @@ async def wait_closed(self) -> None:
737750
if self._server is None or self._server_closed_fut is None:
738751
raise RuntimeError('Server is not started')
739752
await self._server_closed_fut
740-
await self._server.wait_closed()
741753
if self._handlers:
742754
await asyncio.wait({
743755
self._loop.create_task(h.wait_closed()) for h in self._handlers
744756
})
757+
await self._server.wait_closed()
745758

746759
async def __aenter__(self) -> 'Server':
747760
return self

tests/stubs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler):
4747
headers = None
4848
release_stream = None
4949

50+
def connection_made(self, connection):
51+
pass
52+
5053
def accept(self, stream, headers, release_stream):
5154
self.stream = stream
5255
self.headers = headers

0 commit comments

Comments
 (0)