diff --git a/pywa/client.py b/pywa/client.py index cff83ba9..093c902c 100644 --- a/pywa/client.py +++ b/pywa/client.py @@ -13,14 +13,13 @@ import mimetypes import os import pathlib -import threading import warnings from types import NoneType from typing import BinaryIO, Iterable, Literal, Any, Callable import requests -from . import utils, errors +from . import utils from .api import WhatsAppCloudApi from .handlers import Handler, HandlerDecorators, FlowRequestHandler # noqa from .types import ( @@ -219,58 +218,6 @@ def _setup_api( api_version=api_version, ) - def _delayed_register_callback_url( - self, - callback_url: str, - app_id: int, - app_secret: str, - verify_token: str, - fields: tuple[str, ...] | None, - delay: int, - ) -> None: - threading.Timer( - interval=delay, - function=self._register_callback_url, - kwargs=dict( - callback_url=callback_url, - app_id=app_id, - app_secret=app_secret, - verify_token=verify_token, - fields=fields, - ), - ).start() - - def _register_callback_url( - self, - callback_url: str, - app_id: int, - app_secret: str, - verify_token: str, - fields: tuple[str, ...] | None, - ) -> None: - """ - This is a non-blocking function that registers the callback URL. - It must be called after the server is running so that the challenge can be verified. - """ - try: - app_access_token = self.api.get_app_access_token( - app_id=app_id, app_secret=app_secret - ) - # noinspection PyProtectedMember - if not self.api.set_app_callback_url( - app_id=app_id, - app_access_token=app_access_token["access_token"], - callback_url=callback_url, - verify_token=verify_token, - fields=fields, - )["success"]: - raise RuntimeError("Failed to register callback URL.") - _logger.info("Callback URL '%s' registered successfully", callback_url) - except errors.WhatsAppError as e: - raise RuntimeError( - f"Failed to register callback URL '{callback_url}'" - ) from e - def __str__(self) -> str: return f"WhatsApp(phone_id={self.phone_id!r})" diff --git a/pywa/server.py b/pywa/server.py index eff7f3a2..502866eb 100644 --- a/pywa/server.py +++ b/pywa/server.py @@ -4,11 +4,11 @@ import asyncio import logging +import threading from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Any, Tuple, Coroutine -from . import utils, handlers -from .errors import WhatsAppError +from . import utils, handlers, errors from .handlers import Handler, ChatOpenedHandler, TemplateStatusHandler # noqa from .handlers import ( CallbackButtonHandler, @@ -106,6 +106,7 @@ def __init__( self._server = None return self._server = server + self._verify_token = verify_token self._server_type = utils.ServerType.from_app(server) self._executor = ThreadPoolExecutor(max_workers, thread_name_prefix="Handler") self._loop = asyncio.get_event_loop() @@ -116,13 +117,14 @@ def __init__( self._flows_response_encryptor = flows_response_encryptor self._continue_handling = continue_handling self._skip_duplicate_updates = skip_duplicate_updates + self._updates_ids_in_process = set[str]() if not verify_token: raise ValueError( "When listening for incoming updates, a verify token must be provided.\n>> The verify token can " "be any string. It is used to challenge the webhook endpoint to verify that the endpoint is valid." ) - self._register_routes(verify_token=verify_token) + self._register_routes() if callback_url is not None: if app_id is None or app_secret is None: @@ -143,52 +145,50 @@ def __init__( else 0, ) - def _register_routes(self: "WhatsApp", verify_token: str) -> None: - hub_vt = "hub.verify_token" - hub_ch = "hub.challenge" - - async def challenge_handler(vt: str, ch: str) -> tuple[str, int]: - """The challenge function that is called when the callback URL is registered.""" - if vt == verify_token: - _logger.info( - "Webhook ('%s') passed the verification challenge", - self._webhook_endpoint, - ) - return ch, 200 - _logger.error( - "Webhook ('%s') failed the verification challenge. Expected verify_token: %s, received: %s", + async def webhook_challenge_handler(self, vt: str, ch: str) -> tuple[str, int]: + """The challenge function that is called when the callback URL is registered.""" + if vt == self._verify_token: + _logger.info( + "Webhook ('%s') passed the verification challenge", self._webhook_endpoint, - verify_token, - vt, ) - return "Error, invalid verification token", 403 - - self._updates_ids_in_process = set[str]() - - async def webhook_update_handler(update: dict) -> tuple[str, int]: - """The webhook function that is called when an update is received.""" - update_id: str | None = None - _logger.debug( - "Webhook ('%s') received an update: %s", - self._webhook_endpoint, - update, - ) - if self._skip_duplicate_updates and ( - update_id := _extract_id_from_update(update) - ): - if update_id in self._updates_ids_in_process: - _logger.warning( - "Webhook ('%s') received an update with an ID that is already being processed: %s", - self._webhook_endpoint, - update_id, - ) - return "ok", 200 - self._updates_ids_in_process.add(update_id) - await self._call_handlers(update) - if self._skip_duplicate_updates and update_id is not None: - if update_id is not None: - self._updates_ids_in_process.remove(update_id) - return "ok", 200 + return ch, 200 + _logger.error( + "Webhook ('%s') failed the verification challenge. Expected verify_token: %s, received: %s", + self._webhook_endpoint, + self._verify_token, + vt, + ) + return "Error, invalid verification token", 403 + + async def webhook_update_handler(self, update: dict) -> tuple[str, int]: + """The webhook function that is called when an update is received.""" + update_id: str | None = None + _logger.debug( + "Webhook ('%s') received an update: %s", + self._webhook_endpoint, + update, + ) + if self._skip_duplicate_updates and ( + update_id := _extract_id_from_update(update) + ): + if update_id in self._updates_ids_in_process: + _logger.warning( + "Webhook ('%s') received an update with an ID that is already being processed: %s", + self._webhook_endpoint, + update_id, + ) + return "ok", 200 + self._updates_ids_in_process.add(update_id) + await self._call_handlers(update) + if self._skip_duplicate_updates and update_id is not None: + if update_id is not None: + self._updates_ids_in_process.remove(update_id) + return "ok", 200 + + def _register_routes(self: "WhatsApp") -> None: + hub_vt = "hub.verify_token" + hub_ch = "hub.challenge" match self._server_type: case utils.ServerType.FLASK: @@ -199,7 +199,7 @@ async def webhook_update_handler(update: dict) -> tuple[str, int]: @self._server.route(self._webhook_endpoint, methods=["GET"]) @utils.rename_func(f"({self.phone_id})") async def flask_challenge() -> tuple[str, int]: - return await challenge_handler( + return await self.webhook_challenge_handler( vt=flask.request.args.get(hub_vt), ch=flask.request.args.get(hub_ch), ) @@ -207,7 +207,7 @@ async def flask_challenge() -> tuple[str, int]: @self._server.route(self._webhook_endpoint, methods=["POST"]) @utils.rename_func(f"({self.phone_id})") async def flask_webhook() -> tuple[str, int]: - return await webhook_update_handler(flask.request.json) + return await self.webhook_update_handler(flask.request.json) else: # flask @@ -215,18 +215,18 @@ async def flask_webhook() -> tuple[str, int]: @utils.rename_func(f"({self.phone_id})") def flask_challenge() -> tuple[str, int]: return self._loop.run_until_complete( - challenge_handler( + self.webhook_challenge_handler( vt=flask.request.args.get(hub_vt), ch=flask.request.args.get(hub_ch), ) ) - @self._server.route(self._webhook_endpoint, methods=["POST"]) - @utils.rename_func(f"({self.phone_id})") - def flask_webhook() -> tuple[str, int]: - return self._loop.run_until_complete( - webhook_update_handler(flask.request.json) - ) + @self._server.route(self._webhook_endpoint, methods=["POST"]) + @utils.rename_func(f"({self.phone_id})") + def flask_webhook() -> tuple[str, int]: + return self._loop.run_until_complete( + self.webhook_update_handler(flask.request.json) + ) case utils.ServerType.FASTAPI: import fastapi @@ -234,7 +234,7 @@ def flask_webhook() -> tuple[str, int]: @self._server.get(self._webhook_endpoint) @utils.rename_func(f"({self.phone_id})") async def fastapi_challenge(req: fastapi.Request) -> fastapi.Response: - content, status_code = await challenge_handler( + content, status_code = await self.webhook_challenge_handler( vt=req.query_params.get(hub_vt), ch=req.query_params.get(hub_ch) ) return fastapi.Response(content=content, status_code=status_code) @@ -242,7 +242,7 @@ async def fastapi_challenge(req: fastapi.Request) -> fastapi.Response: @self._server.post(self._webhook_endpoint) @utils.rename_func(f"({self.phone_id})") async def fastapi_webhook(req: fastapi.Request) -> fastapi.Response: - content, status_code = await webhook_update_handler( + content, status_code = await self.webhook_update_handler( await req.json() ) return fastapi.Response(content=content, status_code=status_code) @@ -251,36 +251,6 @@ async def fastapi_webhook(req: fastapi.Request) -> fastapi.Response: f"The `server` must be one of {utils.ServerType.protocols_names()}" ) - def _register_callback_url( - self: "WhatsApp", - callback_url: str, - app_id: int, - app_secret: str, - verify_token: str, - fields: tuple[str, ...] | None, - ) -> None: - """ - This is a non-blocking function that registers the callback URL. - It must be called after the server is running so that the challenge can be verified. - """ - full_url = f"{callback_url.rstrip('/')}/{self._webhook_endpoint.lstrip('/')}" - try: - app_access_token = self.api.get_app_access_token( - app_id=app_id, app_secret=app_secret - ) - # noinspection PyProtectedMember - if not self.api.set_app_callback_url( - app_id=app_id, - app_access_token=app_access_token["access_token"], - callback_url=full_url, - verify_token=verify_token, - fields=tuple(fields or Handler._fields_to_subclasses().keys()), - )["success"]: - raise RuntimeError("Failed to register callback URL.") - _logger.info("Callback URL '%s' registered successfully", full_url) - except WhatsAppError as e: - raise RuntimeError(f"Failed to register callback URL '{full_url}'") from e - async def _call_handlers(self: "WhatsApp", update: dict) -> None: """Call the handlers for the given update.""" try: @@ -383,7 +353,59 @@ def _get_handler(self: "WhatsApp", update: dict) -> type[Handler] | None: # noinspection PyProtectedMember return Handler._fields_to_subclasses().get(field) - def _register_flow_endpoint_callback( + def _delayed_register_callback_url( + self: "WhatsApp", + callback_url: str, + app_id: int, + app_secret: str, + verify_token: str, + fields: tuple[str, ...] | None, + delay: int, + ) -> None: + threading.Timer( + interval=delay, + function=self._register_callback_url, + kwargs=dict( + callback_url=callback_url, + app_id=app_id, + app_secret=app_secret, + verify_token=verify_token, + fields=fields, + ), + ).start() + + def _register_callback_url( + self: "WhatsApp", + callback_url: str, + app_id: int, + app_secret: str, + verify_token: str, + fields: tuple[str, ...] | None, + ) -> None: + """ + This is a non-blocking function that registers the callback URL. + It must be called after the server is running so that the challenge can be verified. + """ + try: + app_access_token = self.api.get_app_access_token( + app_id=app_id, app_secret=app_secret + ) + # noinspection PyProtectedMember + if not self.api.set_app_callback_url( + app_id=app_id, + app_access_token=app_access_token["access_token"], + callback_url=callback_url, + verify_token=verify_token, + fields=fields, + )["success"]: + raise RuntimeError("Failed to register callback URL.") + _logger.info("Callback URL '%s' registered successfully", callback_url) + except errors.WhatsAppError as e: + raise RuntimeError( + f"Failed to register callback URL '{callback_url}'" + ) from e + + def get_flow_request_handler( self: "WhatsApp", endpoint: str, callback: handlers._FlowRequestHandlerT, @@ -393,13 +415,7 @@ def _register_flow_endpoint_callback( private_key_password: str | None, request_decryptor: utils.FlowRequestDecryptor | None, response_encryptor: utils.FlowResponseEncryptor | None, - ) -> None: - """Internal function to register a flow endpoint callback.""" - if self._server is None: - raise ValueError( - "You must initialize the WhatsApp client with an web server" - " (Flask or FastAPI) in order to handle incoming flow requests." - ) + ) -> Callable[[dict], Coroutine[Any, Any, tuple[str, int]]]: private_key = private_key or self._private_key private_key_password = private_key_password or self._private_key_password if not private_key: @@ -434,7 +450,7 @@ def _register_flow_endpoint_callback( ) async def flow_request_handler(payload: dict) -> tuple[str, int]: - """Called by the server when a flow request is received. Returns response and status code.""" + """Callback function that handles the incoming flow requests.""" try: decrypted_request, aes_key, iv = request_decryptor( payload["encrypted_flow_data"], @@ -524,6 +540,37 @@ async def flow_request_handler(payload: dict) -> tuple[str, int]: iv, ), 200 + return flow_request_handler + + def _register_flow_endpoint_callback( + self: "WhatsApp", + endpoint: str, + callback: handlers._FlowRequestHandlerT, + acknowledge_errors: bool, + handle_health_check: bool, + private_key: str | None, + private_key_password: str | None, + request_decryptor: utils.FlowRequestDecryptor | None, + response_encryptor: utils.FlowResponseEncryptor | None, + ) -> None: + """Internal function to register a flow endpoint callback.""" + if self._server is None: + raise ValueError( + "You must initialize the WhatsApp client with an web server" + " (Flask or FastAPI) in order to handle incoming flow requests." + ) + + handler = self.get_flow_request_handler( + endpoint=endpoint, + callback=callback, + acknowledge_errors=acknowledge_errors, + handle_health_check=handle_health_check, + private_key=private_key, + private_key_password=private_key_password, + request_decryptor=request_decryptor, + response_encryptor=response_encryptor, + ) + match self._server_type: case utils.ServerType.FLASK: import flask @@ -533,14 +580,14 @@ async def flow_request_handler(payload: dict) -> tuple[str, int]: @self._server.route(endpoint, methods=["POST"]) @utils.rename_func(f"({endpoint})") async def flask_flow() -> tuple[str, int]: - return await flow_request_handler(flask.request.json) + return await handler(flask.request.json) else: @self._server.route(endpoint, methods=["POST"]) @utils.rename_func(f"({endpoint})") def flask_flow() -> tuple[str, int]: return self._loop.run_until_complete( - flow_request_handler(flask.request.json) + handler(flask.request.json) ) case utils.ServerType.FASTAPI: import fastapi @@ -548,7 +595,7 @@ def flask_flow() -> tuple[str, int]: @self._server.post(endpoint) @utils.rename_func(f"({endpoint})") async def fastapi_flow(req: fastapi.Request) -> fastapi.Response: - response, status_code = await flow_request_handler(await req.json()) + response, status_code = await handler(await req.json()) return fastapi.Response( content=response, status_code=status_code,