diff --git a/firebase_messaging/fcmpushclient.py b/firebase_messaging/fcmpushclient.py index 67bc20e..01d52f2 100644 --- a/firebase_messaging/fcmpushclient.py +++ b/firebase_messaging/fcmpushclient.py @@ -10,7 +10,7 @@ from threading import Thread from typing import Any, Callable, Optional, List from dataclasses import dataclass - +from enum import Enum from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_der_private_key from google.protobuf.json_format import MessageToJson @@ -40,6 +40,13 @@ _logger = logging.getLogger(__name__) +class ErrorType(Enum): + CONNECTION = 1 + READ = 2 + LOGIN = 3 + NOTIFY = 4 + + @dataclass class FcmPushClientConfig: # pylint:disable=too-many-instance-attributes """Class to provide configuration to @@ -129,8 +136,8 @@ def __init__( self.is_stopping = False self.tasks = None - self.loop = None - self.main_loop = None + self.listen_event_loop = None + self.callback_event_loop = None self.fcm_thread = None self.app_id = None @@ -160,7 +167,11 @@ def _log_warn_with_limit(self, msg: str, *args): async def _do_writer_close(self): try: - if self.loop and self.writer and self.loop.is_running(): + if ( + self.listen_event_loop + and self.writer + and self.listen_event_loop.is_running() + ): self.writer.close() await self.writer.wait_closed() except OSError as e: @@ -323,7 +334,7 @@ async def _login(self): _logger.debug("Sent login request") except Exception as ex: _logger.error("Received an exception logging in: %s", ex) - if self._try_increment_error_count("login"): + if self._try_increment_error_count(ErrorType.LOGIN): await self._reset() @staticmethod @@ -354,17 +365,17 @@ def _app_data_by_key(self, p, key): raise RuntimeError("couldn't find in app_data {}".format(key)) - def _handle_data_message(self, callback, p, obj): + def _handle_data_message(self, callback, msg, obj): _logger.debug( "Received data message Stream ID: %s, Last: %s, Status: %s", - p.stream_id, - p.last_stream_id_received, - p.status, + msg.stream_id, + msg.last_stream_id_received, + msg.status, ) - crypto_key = self._app_data_by_key(p, "crypto-key")[3:] # strip dh= - salt = self._app_data_by_key(p, "encryption")[5:] # strip salt= - subtype = self._app_data_by_key(p, "subtype") + crypto_key = self._app_data_by_key(msg, "crypto-key")[3:] # strip dh= + salt = self._app_data_by_key(msg, "encryption")[5:] # strip salt= + subtype = self._app_data_by_key(msg, "subtype") if subtype != self.app_id: self._log_warn_with_limit( "Subtype %s in data message does not match" @@ -373,7 +384,7 @@ def _handle_data_message(self, callback, p, obj): self.app_id, ) decrypted = self._decrypt_raw_data( - self.credentials, crypto_key, salt, p.raw_data + self.credentials, crypto_key, salt, msg.raw_data ) try: decrypted_json = json.loads(decrypted.decode("utf-8")) @@ -382,25 +393,37 @@ def _handle_data_message(self, callback, p, obj): ret_val = decrypted_json if decrypted_json else decrypted self._log_verbose( - "Decrypted data for message %s is: %s", p.persistent_id, ret_val + "Decrypted data for message %s is: %s", msg.persistent_id, ret_val ) - if self.main_loop and self.main_loop.is_running(): - if callback: - on_error = functools.partial(self._try_increment_error_count, "notify") - self.main_loop.call_soon_threadsafe( + if callback and self.listen_event_loop != self.callback_event_loop: + if callback and self.callback_event_loop.is_running(): + on_error = functools.partial( + self._try_increment_error_count, ErrorType.NOTIFY + ) + on_success = functools.partial( + self._reset_error_count, ErrorType.NOTIFY + ) + self.callback_event_loop.call_soon_threadsafe( functools.partial( FcmPushClient._wrapped_callback, - self.loop, + self.listen_event_loop, on_error, + on_success, callback, ret_val, - p, + msg.persistent_id, obj, ) ) - else: - _logger.debug("Main loop no longer running, terminating FcmClient") - self._terminate() + elif callback: + try: + callback(ret_val, msg.persistent_id, obj) + self._reset_error_count(ErrorType.NOTIFY) + except Exception as ex: + _logger.error( + "Unexpected exception calling notification callback: %s", ex + ) + self._try_increment_error_count(ErrorType.NOTIFY) def _new_input_stream_id_available(self): return self.last_input_stream_id_reported != self.input_stream_id @@ -465,12 +488,12 @@ def _terminate(self): ): # cancel return if task is done so no need to check task.cancel() - async def _do_monitor(self): + async def _do_monitor(self, callback): while self.do_listen: await asyncio.sleep(self.config.monitor_interval) - if not self.main_loop or not self.main_loop.is_running(): - _logger.debug("Main loop no longer running, terminating FcmClient") + if callback and not self.callback_event_loop.is_running(): + _logger.debug("Callback loop no longer running, terminating FcmClient") self._terminate() return @@ -490,7 +513,10 @@ async def _do_monitor(self): ): await self._reset() - def _try_increment_error_count(self, error_type: str): + def _reset_error_count(self, error_type: ErrorType): + self.sequential_error_counters[error_type] = 0 + + def _try_increment_error_count(self, error_type: ErrorType): if error_type not in self.sequential_error_counters: self.sequential_error_counters[error_type] = 0 @@ -516,18 +542,18 @@ async def _handle_message(self, msg, callback, obj): if isinstance(msg, Close): self._log_warn_with_limit("Server sent Close message, resetting") - if self._try_increment_error_count("connection"): + if self._try_increment_error_count(ErrorType.CONNECTION): await self._reset() return if isinstance(msg, LoginResponse): if str(msg.error): _logger.error("Received login error response: %s", msg) - if self._try_increment_error_count("login"): + if self._try_increment_error_count(ErrorType.LOGIN): await self._reset() else: _logger.info("Succesfully logged in to MCS endpoint") - self.sequential_error_counters["login"] = 0 + self._reset_error_count(ErrorType.LOGIN) self.logged_in = True self.persistent_ids = [] return @@ -546,8 +572,8 @@ async def _handle_message(self, msg, callback, obj): else: self._log_warn_with_limit("Unexpected message type %s.", type(msg).__name__) # Reset error count if a read has been succesful - self.sequential_error_counters["read"] = 0 - self.sequential_error_counters["connection"] = 0 + self._reset_error_count(ErrorType.READ) + self._reset_error_count(ErrorType.CONNECTION) @staticmethod async def _open_connection(host, port, ssl): @@ -610,9 +636,11 @@ async def _listen(self, callback, obj=None): # pylint: disable=too-many-branche try: await self._login() - while self.do_listen and self.loop.is_running(): - if not self.main_loop or not self.main_loop.is_running(): - _logger.debug("Main loop no longer running, terminating FcmClient") + while self.do_listen and self.listen_event_loop.is_running(): + if callback and not self.callback_event_loop.is_running(): + _logger.debug( + "Callback loop no longer running, terminating FcmClient" + ) self._terminate() return try: @@ -623,7 +651,15 @@ async def _listen(self, callback, obj=None): # pylint: disable=too-many-branche except OSError as osex: if ( - isinstance(osex, (ConnectionResetError, TimeoutError, SSLError)) + isinstance( + osex, + ( + ConnectionResetError, + TimeoutError, + asyncio.IncompleteReadError, + SSLError, + ), + ) and self.is_resetting ): if ( @@ -641,7 +677,7 @@ async def _listen(self, callback, obj=None): # pylint: disable=too-many-branche ) else: _logger.error("Unexpected exception during read: %s", osex) - if self._try_increment_error_count("connection"): + if self._try_increment_error_count(ErrorType.CONNECTION): await self._reset() except asyncio.CancelledError as cex: @@ -660,31 +696,42 @@ def _signal_handler(self): self.disconnect() async def _run_tasks(self, callback, obj): + self.reset_lock = asyncio.Lock() + self.stopping_lock = asyncio.Lock() + self.do_listen = True try: self.tasks = [ asyncio.create_task(self._listen(callback, obj)), - asyncio.create_task(self._do_monitor()), + asyncio.create_task(self._do_monitor(callback)), ] - return await asyncio.gather(*self.tasks, return_exceptions=True) + await asyncio.gather(*self.tasks, return_exceptions=True) + _logger.info("FCMClient has shutdown") except Exception as ex: _logger.error("Unexpected error running FcmPushClient: %s", ex) - def _start_connection(self, callback, obj): - self.do_listen = True - self.loop = asyncio.new_event_loop() + def _start_on_new_loop(self, callback, obj): + self.listen_event_loop = asyncio.new_event_loop() + if not self.callback_event_loop: + self.callback_event_loop = self.listen_event_loop - asyncio.set_event_loop(self.loop) - self.loop.run_until_complete(self._run_tasks(callback, obj)) - _logger.info("FCMClient has shutdown") + asyncio.set_event_loop(self.listen_event_loop) + self.listen_event_loop.run_until_complete(self._run_tasks(callback, obj)) @staticmethod def _wrapped_callback( - fcm_client_loop, on_error, callback, notification, msg, obj + fcm_client_loop, + on_error, + on_success, + callback, + notification, + persistent_id, + obj, ): # pylint: disable=too-many-arguments - # Should be running on main loop + # Should be running on callback loop try: - callback(notification, msg.persistent_id, obj) + callback(notification, persistent_id, obj) + fcm_client_loop.call_soon_threadsafe(on_success) except Exception as ex: _logger.error("Unexpected exception calling notification callback: %s", ex) fcm_client_loop.call_soon_threadsafe(on_error) @@ -715,7 +762,12 @@ def checkin(self, sender_id: int, app_id: str) -> str: return self.credentials["fcm"]["token"] def connect( - self, callback: Optional[Callable[[dict, str, Optional[Any]], None]], obj=None + self, + callback: Optional[Callable[[dict, str, Optional[Any]], None]], + obj: Any = None, + *, + listen_event_loop: asyncio.AbstractEventLoop = None, + callback_event_loop: asyncio.AbstractEventLoop = None, ): """Connect to FCM and start listening for push messages on a seperate service thread. @@ -727,23 +779,28 @@ def connect( persistent_id: unique message identifier from the FCM server.\n obj: returns the arbitrary object if supplied to this function. :param obj: Arbitrary object to be returned in the callback. + :param listen_event_loop: If supplied the client will use this event loop + for asyncio communication with the fcm server, otherwise it will create + it's own thread and start an event loop on it. + :param callback_event_loop: If supplied the client will run the callback + on the supplied loop, otherwise it will run the callback on it's own + thread loop or the listen_event_loop if set. """ - try: - self.main_loop = asyncio.get_running_loop() - except RuntimeError: - _logger.error( - "No running event loop, connect failed. " - + "FcMPushClient needs a running event loop to call back on" - ) - return - - self.reset_lock = asyncio.Lock() - self.stopping_lock = asyncio.Lock() + self.listen_event_loop = listen_event_loop + self.callback_event_loop = callback_event_loop - self.fcm_thread = Thread( - target=self._start_connection, args=[callback, obj], daemon=True - ) - self.fcm_thread.start() + if self.listen_event_loop: + if not self.callback_event_loop: + self.callback_event_loop = self.listen_event_loop + self.listen_event_loop.create_task(self._run_tasks(callback, obj)) + else: + self.fcm_thread = Thread( + target=self._start_on_new_loop, + args=[callback, obj], + daemon=True, + name="FcmClientThread", + ) + self.fcm_thread.start() async def _stop_connection(self): if self.stopping_lock.locked() or self.is_stopping: @@ -761,14 +818,24 @@ async def _stop_connection(self): finally: self.is_stopping = False + self.fcm_thread = None + self.listen_event_loop = None def disconnect(self): """Disconnects from FCM and shuts down the service thread.""" - if self.loop and self.loop.is_running() and self.fcm_thread.is_alive(): - _logger.debug("Shutting down FCMClient") - asyncio.run_coroutine_threadsafe(self._stop_connection(), self.loop) + if self.fcm_thread: + if ( + self.listen_event_loop + and self.listen_event_loop.is_running() + and self.fcm_thread.is_alive() + ): + _logger.debug("Shutting down FCMClient") + asyncio.run_coroutine_threadsafe( + self._stop_connection(), self.listen_event_loop + ) - self.loop = None + elif self.listen_event_loop and self.listen_event_loop.is_running(): + self.listen_event_loop.create_task(self._stop_connection()) def register(self, sender_id: int, app_id: str) -> dict: """Register gcm and fcm tokens for sender_id. @@ -810,10 +877,15 @@ async def _send_data_message(self, raw_data, persistent_id): def send_message(self, raw_data, persistent_id): """Not implemented, does nothing atm.""" - asyncio.run_coroutine_threadsafe( - self._send_data_message(raw_data, persistent_id), self.loop - ) + if self.fcm_thread: + asyncio.run_coroutine_threadsafe( + self._send_data_message(raw_data, persistent_id), self.listen_event_loop + ) + else: + self.listen_event_loop.create_task( + self._send_data_message(raw_data, persistent_id) + ) def __del__(self): - if self.loop and self.loop.is_running(): + if self.listen_event_loop and self.listen_event_loop.is_running(): self.disconnect() diff --git a/tests/conftest.py b/tests/conftest.py index e68b7be..4cde05c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,7 @@ """Test configuration for the Ring platform.""" import pytest import requests_mock -import struct -import traceback +import threading import os import select from unittest.mock import MagicMock, DEFAULT, Mock, patch @@ -45,23 +44,33 @@ async def fake_mcs_endpoint(): ep.close() -@pytest.fixture() -async def logged_in_push_client(fake_mcs_endpoint, mocker, caplog): +@pytest.fixture(params=[None, "loop"], ids=["loop_created", "loop_provided"]) +async def logged_in_push_client(request, fake_mcs_endpoint, mocker, caplog): clients = {} caplog.set_level(logging.DEBUG) - async def _logged_in_push_client(credentials, msg_callback, callback_obj = None, *, supress_disconnect=False, **config_kwargs): + listen_loop = asyncio.get_running_loop() if request.param else None + async def _logged_in_push_client(credentials, msg_callback, callback_obj = None, callback_loop=None, *, supress_disconnect=False, **config_kwargs): config = FcmPushClientConfig(**config_kwargs) pr = FcmPushClient(credentials=credentials, config=config) pr.checkin(1234, 4321) - pr.connect(msg_callback, callback_obj) - + + cb_loop = asyncio.get_running_loop() if callback_loop else None + pr.connect(msg_callback, callback_obj, listen_event_loop=listen_loop, callback_event_loop=cb_loop) + msg = await fake_mcs_endpoint.get_message() lr = load_fixture_as_msg("login_response.json", LoginResponse) await fake_mcs_endpoint.put_message(lr) clients[pr] = supress_disconnect + + tc = 1 if listen_loop else 2 + assert len(threading.enumerate()) == tc + if listen_loop: + assert pr.listen_event_loop == asyncio.get_running_loop() + else: + assert pr.listen_event_loop != asyncio.get_running_loop() return pr yield _logged_in_push_client diff --git a/tests/fakes.py b/tests/fakes.py index 14f7acf..951a84f 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -12,14 +12,15 @@ class FakeMcsEndpoint(): def __init__(self): - self.read_queue = asyncio.Queue() self.connection_mock = patch("asyncio.open_connection", side_effect = self.open_connection, autospec=True) self.connection_mock.start() self.client_loop = None + self.init_loop = None self.client_writer = None self.client_reader = None + self.init_loop = asyncio.get_running_loop() def close(self): self.connection_mock.stop() @@ -39,18 +40,27 @@ async def wait_for_connection(self, timeout=10): async def put_message(self, message): await self.wait_for_connection() - asyncio.run_coroutine_threadsafe(self.client_reader.put_message(message), self.client_loop) + if self.init_loop != self.client_loop: + asyncio.run_coroutine_threadsafe(self.client_reader.put_message(message), self.client_loop) + else: + await self.client_reader.put_message(message) async def put_error(self, error): await self.wait_for_connection() - asyncio.run_coroutine_threadsafe(self.client_reader.put_error(error), self.client_loop) + if self.init_loop != self.client_loop: + asyncio.run_coroutine_threadsafe(self.client_reader.put_error(error), self.client_loop) + else: + await self.client_reader.put_error(error) async def get_message(self): await self.wait_for_connection() - fut = asyncio.run_coroutine_threadsafe(self.client_writer.get_message(), self.client_loop) - return fut.result() + if self.init_loop != self.client_loop: + fut = asyncio.run_coroutine_threadsafe(self.client_writer.get_message(), self.client_loop) + return fut.result() + else: + return await self.client_writer.get_message() class FakeReader: def __init__(self): diff --git a/tests/test_fcmpushclient.py b/tests/test_fcmpushclient.py index 32cd25f..24b486b 100644 --- a/tests/test_fcmpushclient.py +++ b/tests/test_fcmpushclient.py @@ -38,23 +38,27 @@ async def test_login(logged_in_push_client, fake_mcs_endpoint, mocker, caplog): assert len([record for record in caplog.records if record.levelname == "ERROR"]) == 0 assert "Succesfully logged in to MCS endpoint" in [record.message for record in caplog.records if record.levelname == "INFO"] -#@pytest.mark.parametrize("raw_data", [1,2,3,6]) -async def test_data_message_receive(logged_in_push_client, fake_mcs_endpoint, mocker, caplog): +@pytest.mark.parametrize("callback_loop", [None, "loop"], ids=["no_cb_loop_param", "cb_loop_param"]) +async def test_data_message_receive(logged_in_push_client, fake_mcs_endpoint, mocker, caplog, callback_loop): notification = None persistent_id = None callback_obj = None + cb_loop = None def on_msg(ntf, psid, obj=None): nonlocal notification nonlocal persistent_id nonlocal callback_obj + nonlocal cb_loop notification = ntf persistent_id = psid callback_obj = obj + cb_loop = asyncio.get_running_loop() credentials = load_fixture_as_dict("credentials.json") obj = "Foobar" - pr = await logged_in_push_client(credentials, on_msg, obj) + cb_loop_param = asyncio.get_running_loop() if callback_loop else None + pr = await logged_in_push_client(credentials, on_msg, obj, cb_loop_param) dms = load_fixture_as_msg("data_message_stanza.json", DataMessageStanza) await fake_mcs_endpoint.put_message(dms) @@ -66,6 +70,11 @@ def on_msg(ntf, psid, obj=None): assert notification == {'foo': 'bar'} assert persistent_id == dms.persistent_id assert obj == callback_obj + + if callback_loop: + assert cb_loop == asyncio.get_running_loop() + else: + assert cb_loop == pr.listen_event_loop async def test_connection_reset(logged_in_push_client, fake_mcs_endpoint, mocker): @@ -98,7 +107,7 @@ async def test_terminate(logged_in_push_client, fake_mcs_endpoint, mocker, error for i in range(1,error_count + 1): await fake_mcs_endpoint.put_error(ConnectionResetError()) - await asyncio.sleep(0.3) + await asyncio.sleep(0.1) # client should reset while it gets errors < abort_on_sequential_error_count then it should terminate if i < error_count: assert pr._reset.call_count == i @@ -116,7 +125,7 @@ async def test_heartbeat_receive(logged_in_push_client, fake_mcs_endpoint, caplo ping = load_fixture_as_msg("heartbeat_ping.json", HeartbeatPing) await fake_mcs_endpoint.put_message(ping) - await asyncio.sleep(0.1) + msg = await fake_mcs_endpoint.get_message() assert isinstance(msg, HeartbeatAck) @@ -129,25 +138,15 @@ async def test_heartbeat_send(logged_in_push_client, fake_mcs_endpoint, mocker, ping = load_fixture_as_msg("heartbeat_ping.json", HeartbeatPing) ack = load_fixture_as_msg("heartbeat_ack.json", HeartbeatAck) await pr._send_heartbeat() - await asyncio.sleep(0.1) + ping_msg = await fake_mcs_endpoint.get_message() - await asyncio.sleep(0.1) + await fake_mcs_endpoint.put_message(ack) await asyncio.sleep(0.1) assert isinstance(ping_msg, HeartbeatPing) assert len([record.message for record in caplog.records if record.levelname == "DEBUG" and "Received heartbeat ack" in record.message] ) == 1 -def test_no_loop(caplog): - - pr = FcmPushClient() - pr.connect(None) - - msg = ( - "No running event loop, connect failed. " + - "FcMPushClient needs a running event loop to call back on" - ) - assert len([record.message for record in caplog.records if record.levelname == "ERROR" and msg == record.message] ) == 1 async def test_decrypt(): def get_app_data_by_key(msg, key):