From d75459fb783c23c82ad1948c93faecf102eff835 Mon Sep 17 00:00:00 2001 From: mickolaua Date: Fri, 16 Aug 2024 19:30:44 +0300 Subject: [PATCH] #None: v0.4.21 (see README_NOTES for details) --- RELEASE_NOTES | 9 ++ aware/__version__.py | 2 +- aware/socket.py | 189 ++++++++++++++++++++++++++++++++--------- pyproject.toml | 3 +- tests/conftest.py | 5 +- tools/__init__.py | 7 ++ tools/notify_subs.py | 87 +++++++++++++++++++ tools/socket_client.py | 48 +++++++++++ 8 files changed, 305 insertions(+), 45 deletions(-) create mode 100644 tools/__init__.py create mode 100644 tools/notify_subs.py create mode 100644 tools/socket_client.py diff --git a/RELEASE_NOTES b/RELEASE_NOTES index f2b1510..77162af 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -204,6 +204,15 @@ not send with observation programs. existing options +## 0.4.21 + +- Fixed issue when several socket clients do not get all data +- Added disconnection of inactive clients in socket server +- Disable `aiomisc` logger in `pytest` config +- Added `notify_subs` tool to send message to Telegram subscribers +- Added `socket_client` tool, which is a simple socket client to receive messages + + # v0.3.0 Since this release, native Windows platform is not supported! diff --git a/aware/__version__.py b/aware/__version__.py index 3e26799..91ed962 100644 --- a/aware/__version__.py +++ b/aware/__version__.py @@ -1,2 +1,2 @@ -__version__ = (0, 4, 20) +__version__ = (0, 4, 21) __strversion__ = "{}.{}.{}".format(__version__) diff --git a/aware/socket.py b/aware/socket.py index 65960a9..9f837b9 100644 --- a/aware/socket.py +++ b/aware/socket.py @@ -26,7 +26,10 @@ hostname = CfgOption("hostname", "127.0.0.1", str, comment="Hostname of the server") port = CfgOption("port", 55555, int, comment="Port of the server") max_connections = CfgOption( - "max_connections", 5, int, comment="Maximum number of connections to the server" + "max_connections", + 5, + lambda x: max(1, x), + comment="Maximum number of connections to the server", ) send_alert_message = CfgOption( "send_alert_message", @@ -41,8 +44,10 @@ comment="Send messages on cancelled alerts via socket connection?", ) client_name_filters = CfgOption( - "client_name_filters", [IPV4_PORT_REGEX], lambda x: [re.compile(i) for i in x], - comment="Regular expression filters to validate client ip:port against" + "client_name_filters", + [IPV4_PORT_REGEX], + lambda x: [re.compile(i) for i in x], + comment="Regular expression filters to validate client ip:port against", ) @@ -81,37 +86,38 @@ async def form_message(data: AlertMessage | DataPackage) -> bytes: return b"" -async def client_task( - queue: asyncio.Queue, reader: asyncio.StreamReader, writer: asyncio.StreamWriter -): - client_addr = writer.get_extra_info("peername") - if is_allowed(*client_addr): - log.debug( - "client is connected from %s; client is in the whitelist", client_addr - ) - while True: - try: - data = await queue.get() +def is_allowed(ip: str, port: int) -> bool: + """Check if the given ip and port are allowed to connect to the server - if data: - msg = await form_message(data) - if msg: - log.debug("Sending observation plan to %s", client_addr) - log.debug("Plan: %s", msg) - writer.write(msg) - await writer.drain() + Parameters + ---------- + ip : str + an IP address + port : int + a port number - queue.task_done() + Returns + ------- + bool + True if the given ip and port are allowed + """ + allowed = False + for i, f in enumerate(client_name_filters.value): + if re.fullmatch(f, f"{ip}:{port}"): + allowed = True + break - except BaseException as e: - log.error("Error sending observation plan: %s", e) - break + return allowed - try: - writer.close() - await writer.wait_closed() - except BrokenPipeError as e: - log.error("Error occured at closing writer for %s: %s", client_addr, e) + +async def try_close_writer(writer: asyncio.StreamWriter, client_addr: tuple[str, int]): + try: + writer.close() + await writer.wait_closed() + except BrokenPipeError as e: + log.error( + "Error occured when closing writer for %s: %s", client_addr, e, exc_info=e + ) # Handle for alternative implementation of TCPClient @@ -161,26 +167,26 @@ async def client_task( # await self.server.serve_forever() -def is_allowed(ip: str, port: int): - allowed = False - for i, f in enumerate(client_name_filters.value): - if re.fullmatch(f, f"{ip}:{port}"): - allowed = True - break - - return allowed - - class SocketServer(TCPServer): def __init__( self, host: str = hostname.value, port: int = port.value, queue: asyncio.Queue = asyncio.Queue(), + max_connections: int = max_connections.value, **kwargs, ): super().__init__(address=host, port=port, **kwargs) self.queue = queue + self.max_connections = max_connections + self._connections = 0 + self._clients: dict[ + tuple[str, int], tuple[asyncio.StreamReader, asyncio.StreamWriter] + ] = {} + + @property + async def num_clients(self) -> int: + return len(self._clients) async def handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter @@ -199,5 +205,106 @@ async def handle_client( ------- None """ - await asyncio.gather(client_task(self.queue, reader, writer)) + await asyncio.gather(self.client_task(reader, writer)) return None + + async def client_task( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + client_addr = writer.get_extra_info("peername") + if is_allowed(*client_addr): + log.debug( + "client is connecting from %s; client is in the whitelist", client_addr + ) + async with asyncio.Lock(): + if self._connections >= max_connections.value: + log.debug( + "Maximum number of connections reached: %d/%d", + self._connections, + max_connections.value, + ) + log.debug("Client will not be connected") + await try_close_writer(writer, client_addr) + return + + await self.add_client(client_addr, reader, writer) + + else: + log.debug( + "client is connecting from %s; client is not allowed", client_addr + ) + await try_close_writer(writer, client_addr) + + async def add_client( + self, + client_addr: tuple[str, int], + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + async with asyncio.Lock(): + self._clients[client_addr] = (reader, writer) + log.debug("Added client %s:%d", *client_addr) + + async def remove_client(self, client_addr: tuple[str, int]): + async with asyncio.Lock(): + removed_client = self._clients.pop(client_addr, None) + if removed_client is not None: + log.debug("Removed client %s:%d", *client_addr) + else: + log.debug("Client %s:%d not found; nothing to remove", *client_addr) + + async def watch_clients(self): + """ + Watch for clients and remove unconnected ones. + """ + while True: + dead_clients = set() + + for addr, streams in self._clients.items(): + async with asyncio.Lock(): + reader, writer = streams + if writer.is_closing(): + dead_clients.add(addr) + log.debug("Found inactive client: %s", addr) + + for addr in dead_clients: + await self.remove_client(addr) + log.debug("Client removed due to inactivity: %s", addr) + + # Sleep here or the event loop will stuck + await asyncio.sleep(0.0) + + async def send_data(self): + """ + Send data to clients over socket connection. + """ + while True: + try: + data = await self.queue.get() + self.queue.task_done() + if data: + msg = await form_message(data) + + for client_addr, streams in self._clients.items(): + reader, writer = streams + if msg: + log.debug("Sending data to %s", client_addr) + log.debug("Data: %s", msg) + try: + writer.write(msg) + await writer.drain() + except Exception as e: + log.debug("Client %s error: %s", client_addr, e) + + await try_close_writer(writer, client_addr) + await self.remove_client(client_addr) + + except BaseException as e: + log.error("Error sending observation plan: %s", e) + finally: + await asyncio.sleep(0.0) + + async def start(self) -> None: + await asyncio.gather(super().start(), self.watch_clients(), self.send_data()) diff --git a/pyproject.toml b/pyproject.toml index 5846f71..b10789e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "AWARE" -version = "0.4.20" +version = "0.4.21" authors = [{name="Nicolai Pankov", email="colinsergesen@gmail.com"}] requires-python = ">=3.9,<3.12" dependencies = [ @@ -69,6 +69,7 @@ dev = [ "pytest-mock", "pytest-cov", "pytest-openfiles", + "pytest-timeout", "wheel", "requests-mock", "pytest-asyncio", diff --git a/tests/conftest.py b/tests/conftest.py index 4d1028f..f163846 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,5 +49,6 @@ def mocksession1(monkeypatch, tmp_path): @pytest.fixture(autouse=True) def no_logging(): - logger = logging.getLogger("aware") - logger.disabled = True + for name in {"aware", "aiomisc"}: + logger = logging.getLogger(name) + logger.disabled = True diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..156b532 --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,7 @@ +""" +Author: Nicolai Pankov (colinsergesen@gmail.com) +__init__.py (c) 2024 +Desc: tools for developing and managing of AWARE +Created: 2024-08-14 +Modified: !date! +""" diff --git a/tools/notify_subs.py b/tools/notify_subs.py new file mode 100644 index 0000000..41990af --- /dev/null +++ b/tools/notify_subs.py @@ -0,0 +1,87 @@ +""" +Author: Nicolai Pankov (colinsergesen@gmail.com) +notify_subs.py (c) 2024 +Desc: notify subscribers about something +Created: 2024-08-13 +Modified: !date! +""" + +import asyncclick as click +from aware.logger import log +from aware.telegram.util import select_subscribers +import asyncio +from aware.telegram.bot import TOKEN +import aiohttp +import ssl + +# bot = Bot(TOKEN) +API_TELEGRAM_URL = "https://api.telegram.org/bot{}/sendMessage" + + +async def send_message( + chat_id: str, message: str, token: str +) -> aiohttp.ClientResponse: + sslcontext = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) + async with aiohttp.ClientSession() as session: + async with session.post( + url=API_TELEGRAM_URL.format(token), + headers={"Content-Type": "application/json"}, + params={ + "chat_id": chat_id, + "text": message, + }, + ) as response: + resp = await response.read() + if response.ok: + click.echo(f"Successfully sent message to {chat_id}") + else: + click.echo( + "Message was not send due to an error. See the server response: " + f"{resp}" + ) + return resp + + +@click.command("Notify subscribers") +@click.option("-m", "--message", help="Message to send to subscribers") +@click.option("-l", "--list-subscribers", help="List subscribers", is_flag=True) +@click.option("-s", "--subscriber", help="Subscriber ids", multiple=True) +async def main(message: str, list_subscribers: bool, subscriber: list[str]): + subs = select_subscribers() + if list_subscribers: + click.echo( + f"{'chat_id':^10} {'alert_type':^10} {'content_type':^10} " + f"{'telescopes':^10}" + ) + click.echo( + f"{'-'*len('chat_id'):^10} {'-'*len('alert_type'):^10} " + f"{'-'*len('content_type'):^10} {'-'*len('telescopes'):^10}" + ) + for subscriber in subs: + alerts = ",".join(subscriber["alert_type"]) + contents = ",".join(subscriber["content_type"]) + scopes = ",".join(subscriber["telescopes"]) + click.echo( + f'{str(subscriber["chat_id"]):^10} ' + f"{alerts:^10} " + f"{contents:^10} " + f"{scopes:^10}" + ) + else: + if message: + sub_ids = set([sub["chat_id"] for sub in subs]) + whitelist = set([int(s) for s in subscriber]) or sub_ids + for sub_id in whitelist: + if sub_id in sub_ids: + log.info("Sending message to subscriber with chat id %d", sub_id) + await send_message(sub_id, message, TOKEN) + else: + log.warning( + "Subscriber with chat id %d is absent in the database", sub_id + ) + else: + log.warning("Message is not provided") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tools/socket_client.py b/tools/socket_client.py new file mode 100644 index 0000000..3bf2335 --- /dev/null +++ b/tools/socket_client.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import asyncio + +import click as click +from aiomisc import entrypoint +from aiomisc.service import TCPClient + +async def readline(reader: asyncio.StreamReader) -> str: + data = await reader.readuntil() + return data.decode() + + +class SocketClient(TCPClient): + + async def handle_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + try: + while True: + msg = await readline(reader) + click.echo(msg) + finally: + await writer.drain() + writer.close() + await writer.wait_closed() + + +def entry_point(host: str, port: int): + with entrypoint(SocketClient(host, port)) as loop: + try: + loop.run_forever() + except KeyboardInterrupt: + exit() + + +@click.command() +@click.argument("host", default="127.0.0.1") +@click.option("-p", "--port", type=int, default=55555, help="Default port") +def main(host: str, port: int): + """ + Listens host:port for alert messages and observation programs. + """ + entry_point(host, port) + + +if __name__ == "__main__": + main()