diff --git a/src/aragorn_app.py b/src/aragorn_app.py
index 92ea0d9..dd564f8 100644
--- a/src/aragorn_app.py
+++ b/src/aragorn_app.py
@@ -15,7 +15,6 @@
 from src.openapi_constructor import construct_open_api_schema
 from src.common import async_query, sync_query, status_query
 from src.default_queries import default_input_sync, default_input_async
-from src.util import get_channel_pool
 from src.otel_config import configure_otel
 
 # declare the FastAPI details
@@ -43,7 +42,7 @@
 config["handlers"]["console"]["level"] = log_level
 config["handlers"]["file"]["level"] = log_level
 config["loggers"]["src"]["level"] = log_level
-#config["loggers"]["aio_pika"]["level"] = log_level
+config["loggers"]["aio_pika"]["level"] = log_level
 
 # load the log config
 logging.config.dictConfig(config)
@@ -51,8 +50,10 @@
 # create a logger
 logger = logging.getLogger(__name__)
 
-# Get rabbitmq channel pool
-channel_pool = get_channel_pool()
+# Get rabbitmq connection
+q_username = os.environ.get("QUEUE_USER", "guest")
+q_password = os.environ.get("QUEUE_PW", "guest")
+q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")
 
 # declare the directory where the async data files will exist
 queue_file_dir = "./queue-files"
@@ -111,13 +112,16 @@ async def subservice_callback(response: PDResponse, guid: str) -> int:
     """
     # init the return html status code
     ret_val: int = 200
+    # init the rbtmq connection
+    connection = None
 
     logger.debug(f"{guid}: Receiving sub-service callback")
 
     try:
-        async with channel_pool.acquire() as channel:
+        connection = await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")
+        async with connection:
+            channel = await connection.channel()
             await channel.get_queue(guid, ensure=True)
-
             # create a file path/name
             fname = "".join(random.choices(string.ascii_lowercase, k=12))
             file_name = f"{queue_file_dir}/{guid}-{fname}-async-data.json"
@@ -133,17 +137,15 @@ async def subservice_callback(response: PDResponse, guid: str) -> int:
                 logger.debug(f"{guid}: Callback message published to queue.")
             else:
                 logger.error(f"{guid}: Callback message publishing to queue failed, type: {type(publish_val)}")
-            # if isinstance(publish_val, spec.Basic.Ack):
-            #     logger.info(f'{guid}: Callback message published to queue.')
-            # else:
-            #     # set the html error code
-            #     ret_val = 422
-            #     logger.error(f'{guid}: Callback message publishing to queue failed, type: {type(publish_val)}')
 
     except Exception as e:
         logger.exception(f"Exception detected while handling sub-service callback using guid {guid}", e)
         # set the html status code
         ret_val = 500
+    finally:
+        # close rbtmq connection if it exists
+        if connection:
+            await connection.close()
 
     # return the response code
     return ret_val
diff --git a/src/service_aggregator.py b/src/service_aggregator.py
index beabfe4..7473ca7 100644
--- a/src/service_aggregator.py
+++ b/src/service_aggregator.py
@@ -1,4 +1,5 @@
 """Literature co-occurrence support."""
+import aio_pika
 import json
 import logging
 import asyncio
@@ -10,7 +11,7 @@
 from string import Template
 from functools import partial
 
-from src.util import create_log_entry, get_channel_pool
+from src.util import create_log_entry
 from src.operations import sort_results_score, filter_results_top_n, filter_kgraph_orphans, filter_message_top_n
 from src.process_db import add_item
 from datetime import datetime
@@ -27,9 +28,6 @@
 logger = logging.getLogger(__name__)
 
-# Get rabbitmq channel pool
-channel_pool = get_channel_pool()
-
 # declare the directory where the async data files will exist
 queue_file_dir = "./queue-files"
 
@@ -280,25 +278,45 @@ async def collect_callback_responses(guid, num_queries):
     return responses
 
+async def get_pika_connection():
+    q_username = os.environ.get("QUEUE_USER", "guest")
+    q_password = os.environ.get("QUEUE_PW", "guest")
+    q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")
+    connection = await aio_pika.connect_robust(host=q_host, login=q_username, password=q_password)
+    return connection
+
 
 async def create_queue(guid):
+    connection = None
     try:
-        async with channel_pool.acquire() as channel:
+        connection = await get_pika_connection()
+        async with connection:
+            channel = await connection.channel()
             # declare the queue using the guid as the key
             queue = await channel.declare_queue(guid)
     except Exception as e:
         logger.error(f"{guid}: Failed to create queue.")
         raise e
+    finally:
+        if connection:
+            await connection.close()
 
 
 async def delete_queue(guid):
+    connection = None
     try:
-        async with channel_pool.acquire() as channel:
-            # declare the queue using the guid as the key
+        connection = await get_pika_connection()
+        async with connection:
+            channel = await connection.channel()
+            # delete the queue using the guid as the key
             queue = await channel.queue_delete(guid)
     except Exception:
         logger.error(f"{guid}: Failed to delete queue.")
         # Deleting queue isn't essential, so we will continue
+    finally:
+        if connection:
+            await connection.close()
+
 
 
 def has_unique_nodes(result):
     """Given a result, return True if all nodes are unique, False otherwise"""
@@ -332,8 +350,11 @@ async def check_for_messages(guid, num_queries, num_previously_received=0):
     responses = []
     CONNECTION_TIMEOUT = 1 * 60  # 1 minutes
     num_responses = num_previously_received
+    connection = None
     try:
-        async with channel_pool.acquire() as channel:
+        connection = await get_pika_connection()
+        async with connection:
+            channel = await connection.channel()
             queue = await channel.get_queue(guid, ensure=True)
             # wait for the response. Timeout after
             async with queue.iterator(timeout=CONNECTION_TIMEOUT) as queue_iter:
@@ -370,6 +391,9 @@ async def check_for_messages(guid, num_queries, num_previously_received=0):
 
     except Exception as e:
         logger.error(f"{guid}: Exception {e}. Returning {num_responses} results we have so far.")
         return responses, True
+    finally:
+        if connection:
+            await connection.close()
 
     return responses, complete
diff --git a/src/util.py b/src/util.py
index bcb9f6d..42fa766 100644
--- a/src/util.py
+++ b/src/util.py
@@ -1,10 +1,5 @@
 """Common Aragorn Utilities."""
-import aio_pika
-from aio_pika.abc import AbstractRobustConnection
-from aio_pika.pool import Pool
-import asyncio
 import datetime
-import os
 
 
 def create_log_entry(msg: str, err_level, timestamp = datetime.datetime.now().isoformat(), code=None) -> dict:
@@ -13,29 +8,3 @@ def create_log_entry(msg: str, err_level, timestamp = datetime.datetime.now().is
 
     # return to the caller
     return ret_val
-
-
-def get_channel_pool():
-    # get the queue connection params
-    q_username = os.environ.get("QUEUE_USER", "guest")
-    q_password = os.environ.get("QUEUE_PW", "guest")
-    q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-
-    async def get_connection() -> AbstractRobustConnection:
-        return await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")
-
-
-    connection_pool: Pool = Pool(get_connection, max_size=4, loop=loop)
-
-
-    async def get_channel() -> aio_pika.Channel:
-        async with connection_pool.acquire() as connection:
-            return await connection.channel()
-
-
-    channel_pool: Pool = Pool(get_channel, max_size=10, loop=loop)
-
-    return channel_pool
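
For context, a minimal, self-contained sketch of the per-call connection pattern this diff adopts (open a robust connection, open a channel, do the work, close the connection). The queue name and message body below are illustrative only; it assumes a reachable RabbitMQ broker and the same QUEUE_USER / QUEUE_PW / QUEUE_HOST environment variables the services read.

import asyncio
import os

import aio_pika


async def publish_example(queue_name: str, body: bytes) -> None:
    # read the same environment variables the services use
    q_username = os.environ.get("QUEUE_USER", "guest")
    q_password = os.environ.get("QUEUE_PW", "guest")
    q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")

    # one robust connection per operation; the context manager closes it on exit
    connection = await aio_pika.connect_robust(host=q_host, login=q_username, password=q_password)
    async with connection:
        channel = await connection.channel()
        # declare the queue so the publish cannot race its creation (illustrative;
        # in the services the queue is created separately, keyed by the guid)
        await channel.declare_queue(queue_name)
        await channel.default_exchange.publish(
            aio_pika.Message(body=body),
            routing_key=queue_name,
        )


if __name__ == "__main__":
    asyncio.run(publish_example("example-guid", b'{"status": "ok"}'))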