From 1f9624ae9770c46b252e909a44b981511da15a23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Pierre?= Date: Mon, 23 Sep 2024 08:35:11 +1200 Subject: [PATCH] [Update] client/server: better connection pooling/reuse --- examples/client.py | 18 +++++++------- examples/helloworld.py | 7 +++++- src/py/extra/client.py | 37 +++++++++++++++++++++-------- src/py/extra/decorators.py | 13 ++++++---- src/py/extra/http/parser.py | 7 +++--- src/py/extra/server.py | 47 ++++++------------------------------- src/py/extra/utils/htmpl.py | 6 +++-- 7 files changed, 66 insertions(+), 69 deletions(-) diff --git a/examples/client.py b/examples/client.py index 17e20b9..1c636d6 100644 --- a/examples/client.py +++ b/examples/client.py @@ -7,7 +7,7 @@ async def main(path: str, host: str = "127.0.0.1", port: int = 8000, ssl: bool = print(f"Connecting to {host}:{port}{path}") # NOTE: Connection pooling does not seem to be working with pooling(idle=3600): - for _ in range(10): + for _ in range(n := 5): async for atom in HTTPClient.Request( host=host, method="GET", @@ -15,25 +15,27 @@ async def main(path: str, host: str = "127.0.0.1", port: int = 8000, ssl: bool = path=path, timeout=10.0, streaming=False, - keepalive=_ < 9, + # NOTE: If you set this to False and you get pooling, + # you'll get a Connection lost, which is expected. 
+ keepalive=_ < n - 1, ssl=ssl, ): - print(" >>> ", atom) - await asyncio.sleep(1.0) + pass + # print(" >>> ", atom) + await asyncio.sleep(0.25) if __name__ == "__main__": import sys - args = sys.argv[1:] + args = sys.argv[1:] or ["/index"] n = len(args) - print(n, args) print( asyncio.run( main( path=args[0], - host=args[1] if n >= 1 else "127.0.0.1", - port=int(args[2]) if n >= 2 else 8000, + host=args[1] if n > 1 else "127.0.0.1", + port=int(args[2]) if n > 2 else 8000, ) ) ) diff --git a/examples/helloworld.py b/examples/helloworld.py index 3ee21e8..b007f85 100644 --- a/examples/helloworld.py +++ b/examples/helloworld.py @@ -2,9 +2,14 @@ class HelloWorld(Service): + def __init__(self): + super().__init__() + self.count: int = 0 + @on(GET="{any}") def helloWorld(self, request: HTTPRequest, any: str) -> HTTPResponse: - return request.respond(b"Hello, World !", "text/plain") + self.count += 1 + return request.respond(f"Hello, World ! #{self.count}", "text/plain") if __name__ == "__main__": diff --git a/src/py/extra/client.py b/src/py/extra/client.py index 4f30939..9fd340c 100644 --- a/src/py/extra/client.py +++ b/src/py/extra/client.py @@ -196,11 +196,12 @@ async def Connect( """Returns a connection to the target from the pool (if the pool is available and has a valid connection to the target), or creates a new connection.""" pools = cls.All.get(None) - return await ( + res = await ( pools[-1].get(target) if pools else Connection.Make(target, idle=idle, timeout=timeout, verified=verified) ) + return res @classmethod def Release(cls, connection: Connection) -> bool: @@ -215,9 +216,7 @@ def Release(cls, connection: Connection) -> bool: connection.close() return False elif pools: - # FIXME: Why are we closing the connection upon release, that - # should totally not be the case. We only close if expired. - connection.close() + # NOTE: We don't close the connection here, as we want to reuse it. 
pools[-1].put(connection) return True else: @@ -243,6 +242,12 @@ def __init__(self, idle: float | None = None): self.connections: dict[ConnectionTarget, list[Connection]] = {} self.idle: float | None = idle + def has( + self, + target: ConnectionTarget, + ) -> bool: + return any(_.isValid is True for _ in self.connections.get(target) or ()) + async def get( self, target: ConnectionTarget, @@ -254,6 +259,7 @@ async def get( # then we close the connection, or return a new one. while cxn: c = cxn.pop() + print("POOOL CONN", c, c.isValid) if c.isValid: return c else: @@ -334,6 +340,7 @@ async def OnRequest( timeout: float | None = 2.0, buffer: int = 32_000, streaming: bool | None = None, + keepalive: bool = False, ) -> AsyncGenerator[HTTPAtom, bool | None]: """Low level function to process HTTP requests with the given connection.""" # We send the line @@ -344,10 +351,18 @@ async def OnRequest( ) if "Host" not in head: head["Host"] = host - # if "Content-Length" not in head: - # head["Content-Length"] = "0" + if not streaming and "Content-Length" not in head: + head["Content-Length"] = ( + "0" + if body is None + else ( + str(body.length) + if isinstance(body, HTTPBodyBlob) + else str(body.expected or "0") + ) + ) if "Connection" not in head: - head["Connection"] = "close" + head["Connection"] = "keep-alive" if keepalive else "close" cxn.writer.write(line) payload = "\r\n".join(f"{k}: {v}" for k, v in head.items()).encode("ascii") cxn.writer.write(payload) @@ -361,7 +376,7 @@ async def OnRequest( read_count: int = 0 # -- # We may have more than one request in each payload when - # HTTP Pipelininig is on. + # HTTP Pipelining is on. 
res: HTTPResponse | None = None while status is HTTPProcessingStatus.Processing and res is None: try: @@ -437,6 +452,7 @@ async def Request( proxy: tuple[str, int] | bool | None = None, connection: Connection | None = None, streaming: bool | None = None, + keepalive: bool = False, ) -> AsyncGenerator[HTTPAtom, None]: """Somewhat high level API to perform an HTTP request.""" @@ -511,6 +527,7 @@ async def Request( cxn, timeout=timeout, streaming=streaming, + keepalive=keepalive, ): yield atom finally: @@ -521,13 +538,13 @@ async def Request( @contextmanager -def pooling(idle: float | None = None) -> Iterator[ConnectionPool]: +def pooling(idle: float | int | None = None) -> Iterator[ConnectionPool]: """Creates a context in which connections will be pooled.""" pool = ConnectionPool().Push(idle=idle) try: yield pool finally: - pool.pop() + pool.pop().release() if __name__ == "__main__": diff --git a/src/py/extra/decorators.py b/src/py/extra/decorators.py index 0693471..1a96777 100644 --- a/src/py/extra/decorators.py +++ b/src/py/extra/decorators.py @@ -1,5 +1,7 @@ from typing import ClassVar, Union, Callable, NamedTuple, TypeVar, Any, cast +from .http.model import HTTPRequest, HTTPResponse + T = TypeVar("T") @@ -37,7 +39,7 @@ def Meta(scope: Any, *, strict: bool = False) -> dict[str, Any]: return cast(dict[str, Any], getattr(scope, "__extra__")) else: if hasattr(scope, "__dict__"): - return scope.__dict__ + return cast(dict[str, Any], scope.__dict__) elif strict: raise RuntimeError(f"Metadata cannot be attached to object: {scope}") else: @@ -127,10 +129,9 @@ def decorator(function: T) -> T: meta.setdefault(Extra.ON_PRIORITY, int(priority)) json_data: Any | None = None for http_method, url in list(methods.items()): - if type(url) not in (list, tuple): - url = (url,) + urls: list[str] | tuple[str, ...] 
= (url,) if isinstance(url, str) else url for method in http_method.upper().split("_"): - for _ in url: + for _ in urls: if method == "JSON": json_data = _ else: @@ -174,7 +175,9 @@ def decorator(function: T, *args: Any, **kwargs: Any) -> T: return decorator -def post(transform: Callable[..., bool]) -> Callable[[T], T]: +def post( + transform: Callable[[HTTPRequest, HTTPResponse], HTTPResponse] +) -> Callable[[T], T]: """Registers the given `transform` as a post-processing step of the decorated function.""" diff --git a/src/py/extra/http/parser.py b/src/py/extra/http/parser.py index 9f82be8..210e4fc 100644 --- a/src/py/extra/http/parser.py +++ b/src/py/extra/http/parser.py @@ -245,7 +245,6 @@ def feed(self, chunk: bytes) -> Iterator[HTTPAtom]: if l is False: # We've parsed the headers headers = self.headers.flush() - line = self.requestLine self.requestHeaders = headers if headers is not None: yield headers @@ -260,6 +259,7 @@ def feed(self, chunk: bytes) -> Iterator[HTTPAtom]: and headers.contentLength == 0, ) ): + line = self.requestLine # That's an early exit yield HTTPRequest( method=line.method, @@ -280,10 +280,11 @@ def feed(self, chunk: bytes) -> Iterator[HTTPAtom]: ) yield HTTPProcessingStatus.Body elif self.parser is self.bodyEOS or self.parser is self.bodyLength: - if line is None or headers is None: + if self.requestLine is None or self.requestHeaders is None: yield HTTPProcessingStatus.BadFormat else: - headers = headers or HTTPHeaders({}) + headers = self.requestHeaders + line = self.requestLine # NOTE: This is an awkward dance around the type checker body = ( self.bodyEOS.flush() diff --git a/src/py/extra/server.py b/src/py/extra/server.py index e273579..3cfe22a 100644 --- a/src/py/extra/server.py +++ b/src/py/extra/server.py @@ -1,8 +1,6 @@ -from typing import Callable, NamedTuple, Any, Coroutine, Self +from typing import Callable, NamedTuple, Any, Coroutine import socket import asyncio -import time -from dataclasses import dataclass from 
.utils.logging import exception, info, warning, event from .utils.io import asWritable from .utils.limits import LimitType, unlimit @@ -60,16 +58,6 @@ class ServerOptions(NamedTuple): ) -@dataclass(slots=True) -class UnclosedSocket: - socket: socket.socket - updated: float - - def touch(self) -> Self: - self.updated = time.time() - return self - - class AIOSocketBodyReader(HTTPBodyReader): __slots__ = ["socket", "loop", "buffer"] @@ -104,7 +92,6 @@ async def OnRequest( *, loop: asyncio.AbstractEventLoop, options: ServerOptions, - unclosed: dict[int, UnclosedSocket], ) -> None: """Asynchronous worker, processing a socket in the context of an application.""" @@ -230,25 +217,9 @@ async def OnRequest( except Exception as e: exception(e) finally: - # FIXME: We should support keep-alive, where we don't close the - # connection right away. However the drawback is that each worker - # is going to linger for longer, waiting for the reader to timeout. - # By default, connections in HTTP/1.1 are keep alive. - # -- - # NOTE: We've exited the keep alive loop, so here we close the client - # connection. - # DEBUG - if keep_alive: - info("Keeping connection alive") - print("TODO SOCKET", client) - # if client.fd in unclosed: - # unclosed[client.fd].touch() - # else: - # unclosed[client.fd] = UnclosedSocket(client, time.time()) - # FIXME: We should add the socket to a queue of lingering sockets - else: - info("Closing connection") - client.close() + # NOTE: The above loop takes care of keep alive, so we always close + # the connection on exit. 
+ client.close() @staticmethod async def SendResponse( @@ -359,15 +330,13 @@ async def Serve( Port=options.port, ) - unclosed: dict[int, UnclosedSocket] = {} - # TODO: Add condition try: while True: if options.condition and not options.condition(): break try: - res = ( - await asyncio.wait_for( + res = await ( + asyncio.wait_for( loop.sock_accept(server), timeout=options.timeout ) if options.timeout @@ -379,9 +348,7 @@ async def Serve( client = res[0] # NOTE: Should do something with the tasks task = loop.create_task( - cls.OnRequest( - app, client, loop=loop, options=options, unclosed=unclosed - ) + cls.OnRequest(app, client, loop=loop, options=options) ) tasks.add(task) task.add_done_callback(tasks.discard) diff --git a/src/py/extra/utils/htmpl.py b/src/py/extra/utils/htmpl.py index 27af14d..19b8884 100644 --- a/src/py/extra/utils/htmpl.py +++ b/src/py/extra/utils/htmpl.py @@ -173,7 +173,9 @@ def f(*children: TNodeContent, **attributes: TAttributeContent) -> Node: v if isinstance(v, list) else ( - [_ for _ in cast(tuple, v)] if isinstance(v, tuple) else [v] + [_ for _ in cast(tuple[TNodeContent], v)] + if isinstance(v, tuple) + else [v] ) ), ) @@ -210,7 +212,7 @@ def __init__(self, name: str, factories: dict[str, NodeFactory]): def __getattribute__(self, name: str) -> Callable[..., Node]: if name.startswith("_"): - return super().__getattribute__(name) + return cast(Callable[..., Node], super().__getattribute__(name)) else: factories = self._factories if name not in factories: