From a8b6572ff598356034955111750c1fa99497c5f3 Mon Sep 17 00:00:00 2001 From: Sehat1137 <29227141+Sehat1137@users.noreply.github.com> Date: Mon, 23 Sep 2024 23:21:22 +0300 Subject: [PATCH] feat: add CLI support for AsgiFastStream (#1782) * feat: add CLI support for AsgiFastStream * tests: ignore signal test at Windows * Update asgi.md * chore: bump version --------- Co-authored-by: sehat1137 Co-authored-by: Nikita Pastukhov Co-authored-by: Pastukhov Nikita --- docs/docs/en/getting-started/asgi.md | 31 ++++ faststream/__about__.py | 2 +- faststream/_internal/__init__.py | 0 faststream/_internal/application.py | 207 +++++++++++++++++++++++++++ faststream/app.py | 197 ++----------------------- faststream/asgi/app.py | 63 +++++++- faststream/cli/main.py | 22 +-- tests/cli/rabbit/test_app.py | 80 +++++++---- tests/cli/test_run.py | 85 +++++++++++ 9 files changed, 463 insertions(+), 224 deletions(-) create mode 100644 faststream/_internal/__init__.py create mode 100644 faststream/_internal/application.py create mode 100644 tests/cli/test_run.py diff --git a/docs/docs/en/getting-started/asgi.md b/docs/docs/en/getting-started/asgi.md index 9006c37e23..512475ccee 100644 --- a/docs/docs/en/getting-started/asgi.md +++ b/docs/docs/en/getting-started/asgi.md @@ -114,6 +114,37 @@ app = AsgiFastStream( Now, your **AsyncAPI HTML** representation can be found by the `/docs` url. +### FastStream object reusage + +You may also use regular `FastStream` application object for similar result + +```python linenums="1" hl_lines="2 9" +from faststream import FastStream +from faststream.nats import NatsBroker +from faststream.asgi import make_ping_asgi, AsgiResponse + +broker = NatsBroker() + +async def liveness_ping(scope, receive, send): + return AsgiResponse(b"", status_code=200) + + +app = FastStream(broker).as_asgi( + asgi_routes=[ + ("/liveness", liveness_ping), + ("/readiness", make_ping_asgi(broker, timeout=5.0)), + ], + asyncapi_path="/docs", +) +``` + +``` tip + For app which use ASGI you may use cli command like for default FastStream app + + ```shell + faststream run main:app --host 0.0.0.0 --port 8000 --workers 4 + ``` + ## Other ASGI Compatibility Moreover, our wrappers can be used as ready-to-use endpoins for other **ASGI** frameworks. This can be very helpful When you are running **FastStream** in the same runtime as any other **ASGI** frameworks. diff --git a/faststream/__about__.py b/faststream/__about__.py index df40fc6c6a..67ba1f9f2e 100644 --- a/faststream/__about__.py +++ b/faststream/__about__.py @@ -1,5 +1,5 @@ """Simple and fast framework to create message brokers based microservices.""" -__version__ = "0.5.23" +__version__ = "0.5.24" SERVICE_NAME = f"faststream-{__version__}" diff --git a/faststream/_internal/__init__.py b/faststream/_internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/faststream/_internal/application.py b/faststream/_internal/application.py new file mode 100644 index 0000000000..dd0140db4d --- /dev/null +++ b/faststream/_internal/application.py @@ -0,0 +1,207 @@ +import logging +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, +) + +import anyio +from typing_extensions import ParamSpec + +from faststream.asyncapi.proto import AsyncAPIApplication +from faststream.log.logging import logger +from faststream.utils import apply_types, context +from faststream.utils.functions import drop_response_type, fake_context, to_async + +P_HookParams = ParamSpec("P_HookParams") +T_HookReturn = TypeVar("T_HookReturn") + + +if TYPE_CHECKING: + from faststream.asyncapi.schema import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, + ) + from faststream.broker.core.usecase import BrokerUsecase + from faststream.types import ( + AnyDict, + AnyHttpUrl, + AsyncFunc, + Lifespan, + LoggerProto, + SettingField, + ) + + +class Application(ABC, AsyncAPIApplication): + def __init__( + self, + broker: Optional["BrokerUsecase[Any, Any]"] = None, + logger: Optional["LoggerProto"] = logger, + lifespan: Optional["Lifespan"] = None, + # AsyncAPI args, + title: str = "FastStream", + version: str = "0.1.0", + description: str = "", + terms_of_service: Optional["AnyHttpUrl"] = None, + license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, + contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, + tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, + external_docs: Optional[ + Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] + ] = None, + identifier: Optional[str] = None, + on_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + after_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + on_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + after_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + ) -> None: + context.set_global("app", self) + + self._should_exit = anyio.Event() + self.broker = broker + self.logger = logger + self.context = context + + self._on_startup_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in on_startup + ] + self._after_startup_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in after_startup + ] + self._on_shutdown_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in on_shutdown + ] + self._after_shutdown_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in after_shutdown + ] + + if lifespan is not None: + self.lifespan_context = apply_types( + func=lifespan, wrap_model=drop_response_type + ) + else: + self.lifespan_context = fake_context + + # AsyncAPI information + self.title = title + self.version = version + self.description = description + self.terms_of_service = terms_of_service + self.license = license + self.contact = contact + self.identifier = identifier + self.asyncapi_tags = tags + self.external_docs = external_docs + + @abstractmethod + async def run( + self, + log_level: int, + run_extra_options: Optional[Dict[str, "SettingField"]] = None, + sleep_time: float = 0.1, + ) -> None: ... + + def set_broker(self, broker: "BrokerUsecase[Any, Any]") -> None: + """Set already existed App object broker. + + Useful then you create/init broker in `on_startup` hook. + """ + self.broker = broker + + def on_startup( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running BEFORE broker connected. + + This hook also takes an extra CLI options as a kwargs. + """ + self._on_startup_calling.append(apply_types(to_async(func))) + return func + + def on_shutdown( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running BEFORE broker disconnected.""" + self._on_shutdown_calling.append(apply_types(to_async(func))) + return func + + def after_startup( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running AFTER broker connected.""" + self._after_startup_calling.append(apply_types(to_async(func))) + return func + + def after_shutdown( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running AFTER broker disconnected.""" + self._after_shutdown_calling.append(apply_types(to_async(func))) + return func + + def exit(self) -> None: + """Stop application manually.""" + self._should_exit.set() + + async def start( + self, + **run_extra_options: "SettingField", + ) -> None: + """Executes startup hooks and start broker.""" + for func in self._on_startup_calling: + await func(**run_extra_options) + + if self.broker is not None: + await self.broker.start() + + for func in self._after_startup_calling: + await func() + + async def stop(self) -> None: + """Executes shutdown hooks and stop broker.""" + for func in self._on_shutdown_calling: + await func() + + if self.broker is not None: + await self.broker.close() + + for func in self._after_shutdown_calling: + await func() + + async def _startup( + self, + log_level: int = logging.INFO, + run_extra_options: Optional[Dict[str, "SettingField"]] = None, + ) -> None: + self._log(log_level, "FastStream app starting...") + await self.start(**(run_extra_options or {})) + self._log( + log_level, "FastStream app started successfully! To exit, press CTRL+C" + ) + + async def _shutdown(self, log_level: int = logging.INFO) -> None: + self._log(log_level, "FastStream app shutting down...") + await self.stop() + self._log(log_level, "FastStream app shut down gracefully.") + + def _log(self, level: int, message: str) -> None: + if self.logger is not None: + self.logger.log(level, message) diff --git a/faststream/app.py b/faststream/app.py index a33baa3fee..3f6c2d546f 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -1,162 +1,35 @@ -import logging.config +import logging from typing import ( TYPE_CHECKING, - Any, AsyncIterator, - Callable, Dict, - List, Optional, Sequence, + Tuple, TypeVar, - Union, ) import anyio from typing_extensions import Annotated, ParamSpec, deprecated from faststream._compat import ExceptionGroup -from faststream.asyncapi.proto import AsyncAPIApplication +from faststream._internal.application import Application +from faststream.asgi.app import AsgiFastStream from faststream.cli.supervisors.utils import set_exit from faststream.exceptions import ValidationError -from faststream.log.logging import logger -from faststream.utils import apply_types, context -from faststream.utils.functions import drop_response_type, fake_context, to_async +from faststream.utils.functions import fake_context P_HookParams = ParamSpec("P_HookParams") T_HookReturn = TypeVar("T_HookReturn") if TYPE_CHECKING: - from faststream.asyncapi.schema import ( - Contact, - ContactDict, - ExternalDocs, - ExternalDocsDict, - License, - LicenseDict, - Tag, - TagDict, - ) - from faststream.broker.core.usecase import BrokerUsecase - from faststream.types import ( - AnyCallable, - AnyDict, - AnyHttpUrl, - AsyncFunc, - Lifespan, - LoggerProto, - SettingField, - ) - - -class FastStream(AsyncAPIApplication): - """A class representing a FastStream application.""" - - _on_startup_calling: List["AsyncFunc"] - _after_startup_calling: List["AsyncFunc"] - _on_shutdown_calling: List["AsyncFunc"] - _after_shutdown_calling: List["AsyncFunc"] - - def __init__( - self, - broker: Optional["BrokerUsecase[Any, Any]"] = None, - logger: Optional["LoggerProto"] = logger, - lifespan: Optional["Lifespan"] = None, - # AsyncAPI args, - title: str = "FastStream", - version: str = "0.1.0", - description: str = "", - terms_of_service: Optional["AnyHttpUrl"] = None, - license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, - contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, - tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, - external_docs: Optional[ - Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] - ] = None, - identifier: Optional[str] = None, - on_startup: Sequence["AnyCallable"] = (), - after_startup: Sequence["AnyCallable"] = (), - on_shutdown: Sequence["AnyCallable"] = (), - after_shutdown: Sequence["AnyCallable"] = (), - # all options should be copied to AsgiFastStream class - ) -> None: - context.set_global("app", self) - - self.broker = broker - self.logger = logger - self.context = context - - self._on_startup_calling = [apply_types(to_async(x)) for x in on_startup] - self._after_startup_calling = [apply_types(to_async(x)) for x in after_startup] - self._on_shutdown_calling = [apply_types(to_async(x)) for x in on_shutdown] - self._after_shutdown_calling = [ - apply_types(to_async(x)) for x in after_shutdown - ] - - self.lifespan_context = ( - apply_types( - func=lifespan, - wrap_model=drop_response_type, - ) - if lifespan is not None - else fake_context - ) - - self._should_exit = anyio.Event() - - # AsyncAPI information - self.title = title - self.version = version - self.description = description - self.terms_of_service = terms_of_service - self.license = license - self.contact = contact - self.identifier = identifier - self.asyncapi_tags = tags - self.external_docs = external_docs - - def set_broker(self, broker: "BrokerUsecase[Any, Any]") -> None: - """Set already existed App object broker. - - Useful then you create/init broker in `on_startup` hook. - """ - self.broker = broker - - def on_startup( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running BEFORE broker connected. + from faststream.asgi.types import ASGIApp + from faststream.types import SettingField - This hook also takes an extra CLI options as a kwargs. - """ - self._on_startup_calling.append(apply_types(to_async(func))) - return func - def on_shutdown( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running BEFORE broker disconnected.""" - self._on_shutdown_calling.append(apply_types(to_async(func))) - return func - - def after_startup( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running AFTER broker connected.""" - self._after_startup_calling.append(apply_types(to_async(func))) - return func - - def after_shutdown( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running AFTER broker disconnected.""" - self._after_shutdown_calling.append(apply_types(to_async(func))) - return func +class FastStream(Application): + """A class representing a FastStream application.""" async def run( self, @@ -188,54 +61,12 @@ async def run( for ex in e.exceptions: raise ex from None - def exit(self) -> None: - """Stop application manually.""" - self._should_exit.set() - - async def start( - self, - **run_extra_options: "SettingField", - ) -> None: - """Executes startup hooks and start broker.""" - for func in self._on_startup_calling: - await func(**run_extra_options) - - if self.broker is not None: - await self.broker.start() - - for func in self._after_startup_calling: - await func() - - async def stop(self) -> None: - """Executes shutdown hooks and stop broker.""" - for func in self._on_shutdown_calling: - await func() - - if self.broker is not None: - await self.broker.close() - - for func in self._after_shutdown_calling: - await func() - - async def _startup( + def as_asgi( self, - log_level: int = logging.INFO, - run_extra_options: Optional[Dict[str, "SettingField"]] = None, - ) -> None: - self._log(log_level, "FastStream app starting...") - await self.start(**(run_extra_options or {})) - self._log( - log_level, "FastStream app started successfully! To exit, press CTRL+C" - ) - - async def _shutdown(self, log_level: int = logging.INFO) -> None: - self._log(log_level, "FastStream app shutting down...") - await self.stop() - self._log(log_level, "FastStream app shut down gracefully.") - - def _log(self, level: int, message: str) -> None: - if self.logger is not None: - self.logger.log(level, message) + asgi_routes: Sequence[Tuple[str, "ASGIApp"]] = (), + asyncapi_path: Optional[str] = None, + ) -> AsgiFastStream: + return AsgiFastStream.from_app(self, asgi_routes, asyncapi_path) try: diff --git a/faststream/asgi/app.py b/faststream/asgi/app.py index c91a4e0517..36685f23fe 100644 --- a/faststream/asgi/app.py +++ b/faststream/asgi/app.py @@ -1,9 +1,11 @@ +import logging import traceback from contextlib import asynccontextmanager from typing import ( TYPE_CHECKING, Any, AsyncIterator, + Dict, Optional, Sequence, Tuple, @@ -12,7 +14,7 @@ import anyio -from faststream.app import FastStream +from faststream._internal.application import Application from faststream.asgi.factories import make_asyncapi_asgi from faststream.asgi.response import AsgiResponse from faststream.asgi.websocket import WebSocketClose @@ -37,10 +39,11 @@ AnyHttpUrl, Lifespan, LoggerProto, + SettingField, ) -class AsgiFastStream(FastStream): +class AsgiFastStream(Application): def __init__( self, broker: Optional["BrokerUsecase[Any, Any]"] = None, @@ -90,6 +93,36 @@ def __init__( if asyncapi_path: self.mount(asyncapi_path, make_asyncapi_asgi(self)) + @classmethod + def from_app( + cls, + app: Application, + asgi_routes: Sequence[Tuple[str, "ASGIApp"]], + asyncapi_path: Optional[str] = None, + ) -> "AsgiFastStream": + asgi_app = cls( + app.broker, + asgi_routes=asgi_routes, + asyncapi_path=asyncapi_path, + logger=app.logger, + lifespan=None, + title=app.title, + version=app.version, + description=app.description, + terms_of_service=app.terms_of_service, + license=app.license, + contact=app.contact, + tags=app.asyncapi_tags, + external_docs=app.external_docs, + identifier=app.identifier, + ) + asgi_app.lifespan_context = app.lifespan_context + asgi_app._on_startup_calling = app._on_startup_calling + asgi_app._after_startup_calling = app._after_startup_calling + asgi_app._on_shutdown_calling = app._on_shutdown_calling + asgi_app._after_shutdown_calling = app._after_shutdown_calling + return asgi_app + def mount(self, path: str, route: "ASGIApp") -> None: self.routes.append((path, route)) @@ -107,6 +140,30 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No await self.not_found(scope, receive, send) return + async def run( + self, + log_level: int = logging.INFO, + run_extra_options: Optional[Dict[str, "SettingField"]] = None, + sleep_time: float = 0.1, + ) -> None: + import uvicorn + + if not run_extra_options: + run_extra_options = {} + port = int(run_extra_options.pop("port", 8000)) # type: ignore[arg-type] + workers = int(run_extra_options.pop("workers", 1)) # type: ignore[arg-type] + host = str(run_extra_options.pop("host", "localhost")) + config = uvicorn.Config( + self, + host=host, + port=port, + log_level=log_level, + workers=workers, + **run_extra_options, + ) + server = uvicorn.Server(config) + await server.serve() + @asynccontextmanager async def start_lifespan_context(self) -> AsyncIterator[None]: async with anyio.create_task_group() as tg, self.lifespan_context(): @@ -141,7 +198,7 @@ async def lifespan(self, scope: "Scope", receive: "Receive", send: "Send") -> No await send({"type": "lifespan.shutdown.complete"}) async def not_found(self, scope: "Scope", receive: "Receive", send: "Send") -> None: - not_found_msg = "FastStream doesn't support regular HTTP protocol." + not_found_msg = "App doesn't support regular HTTP protocol." if scope["type"] == "websocket": websocket_close = WebSocketClose( diff --git a/faststream/cli/main.py b/faststream/cli/main.py index 900a36d810..36d321ab8d 100644 --- a/faststream/cli/main.py +++ b/faststream/cli/main.py @@ -11,6 +11,7 @@ from faststream import FastStream from faststream.__about__ import __version__ +from faststream._internal.application import Application from faststream.cli.docs.app import docs_app from faststream.cli.utils.imports import import_from_string from faststream.cli.utils.logs import LogLevels, get_log_level, set_log_level @@ -109,6 +110,7 @@ def run( app, extra = parse_cli_args(app, *ctx.args) casted_log_level = get_log_level(log_level) + module_path, app_obj = import_from_string(app) if app_dir: # pragma: no branch sys.path.insert(0, app_dir) @@ -126,8 +128,6 @@ def run( _run(*args) else: - module_path, _ = import_from_string(app) - if app_dir != ".": reload_dirs = [str(module_path), app_dir] else: @@ -142,11 +142,15 @@ def run( elif workers > 1: from faststream.cli.supervisors.multiprocess import Multiprocess - Multiprocess( - target=_run, - args=(*args, logging.DEBUG), - workers=workers, - ).run() + if isinstance(app_obj, FastStream): + Multiprocess( + target=_run, + args=(*args, logging.DEBUG), + workers=workers, + ).run() + else: + args[1]["workers"] = workers + _run(*args) else: _run(*args) @@ -165,9 +169,9 @@ def _run( if is_factory and callable(app_obj): app_obj = app_obj() - if not isinstance(app_obj, FastStream): + if not isinstance(app_obj, Application): raise typer.BadParameter( - f'Imported object "{app_obj}" must be "FastStream" type.', + f'Imported object "{app_obj}" must be "Application" type.', ) if log_level > 0: diff --git a/tests/cli/rabbit/test_app.py b/tests/cli/rabbit/test_app.py index 4765ffac9a..68fc2517e7 100644 --- a/tests/cli/rabbit/test_app.py +++ b/tests/cli/rabbit/test_app.py @@ -239,34 +239,6 @@ async def lifespan(env: str): mock.off.assert_called_once() -@pytest.mark.asyncio -@pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") -async def test_stop_with_sigint(async_mock, app: FastStream): - with patch.object(app.broker, "start", async_mock.broker_run_sigint), patch.object( - app.broker, "close", async_mock.broker_stopped_sigint - ): - async with anyio.create_task_group() as tg: - tg.start_soon(app.run) - tg.start_soon(_kill, signal.SIGINT) - - async_mock.broker_run_sigint.assert_called_once() - async_mock.broker_stopped_sigint.assert_called_once() - - -@pytest.mark.asyncio -@pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") -async def test_stop_with_sigterm(async_mock, app: FastStream): - with patch.object(app.broker, "start", async_mock.broker_run_sigterm), patch.object( - app.broker, "close", async_mock.broker_stopped_sigterm - ): - async with anyio.create_task_group() as tg: - tg.start_soon(app.run) - tg.start_soon(_kill, signal.SIGTERM) - - async_mock.broker_run_sigterm.assert_called_once() - async_mock.broker_stopped_sigterm.assert_called_once() - - @pytest.mark.asyncio async def test_test_app(mock: Mock): app = FastStream() @@ -364,5 +336,57 @@ async def lifespan(env: str): async_mock.broker_stopped.assert_called_once() +@pytest.mark.asyncio +@pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") +async def test_stop_with_sigint(async_mock, app: FastStream): + with patch.object(app.broker, "start", async_mock.broker_run_sigint), patch.object( + app.broker, "close", async_mock.broker_stopped_sigint + ): + async with anyio.create_task_group() as tg: + tg.start_soon(app.run) + tg.start_soon(_kill, signal.SIGINT) + + async_mock.broker_run_sigint.assert_called_once() + async_mock.broker_stopped_sigint.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") +async def test_stop_with_sigterm(async_mock, app: FastStream): + with patch.object(app.broker, "start", async_mock.broker_run_sigterm), patch.object( + app.broker, "close", async_mock.broker_stopped_sigterm + ): + async with anyio.create_task_group() as tg: + tg.start_soon(app.run) + tg.start_soon(_kill, signal.SIGTERM) + + async_mock.broker_run_sigterm.assert_called_once() + async_mock.broker_stopped_sigterm.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") +async def test_run_asgi(async_mock: AsyncMock, app: FastStream): + asgi_routes = [("/", lambda scope, receive, send: None)] + asgi_app = app.as_asgi(asgi_routes=asgi_routes) + assert asgi_app.broker is app.broker + assert asgi_app.logger is app.logger + assert asgi_app.lifespan_context is app.lifespan_context + assert asgi_app._on_startup_calling is app._on_startup_calling + assert asgi_app._after_startup_calling is app._after_startup_calling + assert asgi_app._on_shutdown_calling is app._on_shutdown_calling + assert asgi_app._after_shutdown_calling is app._after_shutdown_calling + assert asgi_app.routes == asgi_routes + + with patch.object(app.broker, "start", async_mock.broker_run), patch.object( + app.broker, "close", async_mock.broker_stopped + ): + async with anyio.create_task_group() as tg: + tg.start_soon(app.run) + tg.start_soon(_kill, signal.SIGINT) + + async_mock.broker_run.assert_called_once() + + async def _kill(sig): os.kill(os.getpid(), sig) diff --git a/tests/cli/test_run.py b/tests/cli/test_run.py new file mode 100644 index 0000000000..7bff4497f7 --- /dev/null +++ b/tests/cli/test_run.py @@ -0,0 +1,85 @@ +import logging +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from typer.testing import CliRunner + +from faststream.asgi import AsgiFastStream +from faststream.cli.main import cli as faststream_app + + +def test_run_as_asgi(runner: CliRunner): + app = AsgiFastStream() + app.run = AsyncMock() + + with patch("faststream.cli.main.import_from_string", return_value=(None, app)): + result = runner.invoke( + faststream_app, + [ + "run", + "faststream:app", + "--host", + "0.0.0.0", + "--port", + "8000", + ], + ) + app.run.assert_awaited_once_with( + logging.INFO, {"host": "0.0.0.0", "port": "8000"} + ) + assert result.exit_code == 0 + + +@pytest.mark.parametrize("workers", [1, 2, 5]) +def test_run_as_asgi_with_workers(runner: CliRunner, workers: int): + app = AsgiFastStream() + app.run = AsyncMock() + + with patch("faststream.cli.main.import_from_string", return_value=(None, app)): + result = runner.invoke( + faststream_app, + [ + "run", + "faststream:app", + "--host", + "0.0.0.0", + "--port", + "8000", + "--workers", + str(workers), + ], + ) + extra = {"workers": workers} if workers > 1 else {} + + app.run.assert_awaited_once_with( + logging.INFO, {"host": "0.0.0.0", "port": "8000", **extra} + ) + assert result.exit_code == 0 + + +def test_run_as_asgi_callable(runner: CliRunner): + app = AsgiFastStream() + app.run = AsyncMock() + + app_factory = Mock(return_value=app) + + with patch( + "faststream.cli.main.import_from_string", return_value=(None, app_factory) + ): + result = runner.invoke( + faststream_app, + [ + "run", + "faststream:app", + "--host", + "0.0.0.0", + "--port", + "8000", + "--factory", + ], + ) + app_factory.assert_called_once() + app.run.assert_awaited_once_with( + logging.INFO, {"host": "0.0.0.0", "port": "8000"} + ) + assert result.exit_code == 0