
Commit

Revert rbtmq connection pool to single connection
maximusunc committed Jul 21, 2023
1 parent ff2cfcd commit cda8db2
Showing 3 changed files with 46 additions and 51 deletions.
26 changes: 14 additions & 12 deletions src/aragorn_app.py
@@ -15,7 +15,6 @@
from src.openapi_constructor import construct_open_api_schema
from src.common import async_query, sync_query, status_query
from src.default_queries import default_input_sync, default_input_async
from src.util import get_channel_pool
from src.otel_config import configure_otel

# declare the FastAPI details
@@ -43,16 +42,18 @@
config["handlers"]["console"]["level"] = log_level
config["handlers"]["file"]["level"] = log_level
config["loggers"]["src"]["level"] = log_level
#config["loggers"]["aio_pika"]["level"] = log_level
config["loggers"]["aio_pika"]["level"] = log_level

# load the log config
logging.config.dictConfig(config)

# create a logger
logger = logging.getLogger(__name__)

# Get rabbitmq channel pool
channel_pool = get_channel_pool()
# Get rabbitmq connection
q_username = os.environ.get("QUEUE_USER", "guest")
q_password = os.environ.get("QUEUE_PW", "guest")
q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")

# declare the directory where the async data files will exist
queue_file_dir = "./queue-files"
@@ -111,13 +112,16 @@ async def subservice_callback(response: PDResponse, guid: str) -> int:
"""
# init the return html status code
ret_val: int = 200
# init the rbtmq connection
connection = None

logger.debug(f"{guid}: Receiving sub-service callback")

try:
async with channel_pool.acquire() as channel:
connection = await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")
async with connection:
channel = await connection.channel()
await channel.get_queue(guid, ensure=True)

# create a file path/name
fname = "".join(random.choices(string.ascii_lowercase, k=12))
file_name = f"{queue_file_dir}/{guid}-{fname}-async-data.json"
@@ -133,17 +137,15 @@ async def subservice_callback(response: PDResponse, guid: str) -> int:
logger.debug(f"{guid}: Callback message published to queue.")
else:
logger.error(f"{guid}: Callback message publishing to queue failed, type: {type(publish_val)}")
# if isinstance(publish_val, spec.Basic.Ack):
# logger.info(f'{guid}: Callback message published to queue.')
# else:
# # set the html error code
# ret_val = 422
# logger.error(f'{guid}: Callback message publishing to queue failed, type: {type(publish_val)}')

except Exception as e:
logger.exception(f"Exception detected while handling sub-service callback using guid {guid}", e)
# set the html status code
ret_val = 500
finally:
# close rbtmq connection if it exists
if connection:
await connection.close()

# return the response code
return ret_val
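
For reference, the queue interaction that subservice_callback performs after this change can be condensed into the sketch below: one robust connection per callback, publish to the per-guid queue, close the connection in a finally block. This is a minimal sketch rather than the full handler; the helper name publish_callback is hypothetical, and the message body is an assumption (the collapsed part of the diff only shows the publish result being checked), taken here to be the path of the saved response file.

import json
import os

import aio_pika

q_username = os.environ.get("QUEUE_USER", "guest")
q_password = os.environ.get("QUEUE_PW", "guest")
q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")


async def publish_callback(guid: str, file_name: str) -> int:
    """Hypothetical condensation of the handler's queue work; returns an HTTP-style status."""
    connection = None
    try:
        # one robust connection per callback instead of a shared channel pool
        connection = await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")
        async with connection:
            channel = await connection.channel()
            # the queue is declared elsewhere with the guid as its name
            await channel.get_queue(guid, ensure=True)
            # assumed payload: the path of the file holding the callback response
            await channel.default_exchange.publish(
                aio_pika.Message(body=json.dumps({"file": file_name}).encode()),
                routing_key=guid,
            )
        return 200
    except Exception:
        # the real handler logs the exception before returning 500
        return 500
    finally:
        # closing an already-closed connection is a no-op, mirroring the diff's finally block
        if connection:
            await connection.close()
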
40 changes: 32 additions & 8 deletions src/service_aggregator.py
@@ -1,4 +1,5 @@
"""Literature co-occurrence support."""
import aio_pika
import json
import logging
import asyncio
Expand All @@ -10,7 +11,7 @@
from string import Template

from functools import partial
from src.util import create_log_entry, get_channel_pool
from src.util import create_log_entry
from src.operations import sort_results_score, filter_results_top_n, filter_kgraph_orphans, filter_message_top_n
from src.process_db import add_item
from datetime import datetime
@@ -27,9 +28,6 @@

logger = logging.getLogger(__name__)

# Get rabbitmq channel pool
channel_pool = get_channel_pool()

# declare the directory where the async data files will exist
queue_file_dir = "./queue-files"

@@ -279,25 +277,45 @@ async def collect_callback_responses(guid, num_queries):

return responses

async def get_pika_connection():
q_username = os.environ.get("QUEUE_USER", "guest")
q_password = os.environ.get("QUEUE_PW", "guest")
q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")
connection = await aio_pika.connect_robust(host=q_host, login=q_username, password=q_password)
return connection


async def create_queue(guid):
connection = None
try:
async with channel_pool.acquire() as channel:
connection = await get_pika_connection()
async with connection:
channel = await connection.channel()
# declare the queue using the guid as the key
queue = await channel.declare_queue(guid)
except Exception as e:
logger.error(f"{guid}: Failed to create queue.")
raise e
finally:
if connection:
await connection.close()


async def delete_queue(guid):
connection = None
try:
async with channel_pool.acquire() as channel:
# declare the queue using the guid as the key
connection = await get_pika_connection()
async with connection:
channel = await connection.channel()
# delete the queue using the guid as the key
queue = await channel.queue_delete(guid)
except Exception:
logger.error(f"{guid}: Failed to delete queue.")
# Deleting queue isn't essential, so we will continue
finally:
if connection:
await connection.close()


def has_unique_nodes(result):
"""Given a result, return True if all nodes are unique, False otherwise"""
@@ -331,8 +349,11 @@ async def check_for_messages(guid, num_queries, num_previously_received=0):
responses = []
CONNECTION_TIMEOUT = 1 * 60 # 1 minutes
num_responses = num_previously_received
connection = None
try:
async with channel_pool.acquire() as channel:
connection = await get_pika_connection()
async with connection:
channel = await connection.channel()
queue = await channel.get_queue(guid, ensure=True)
# wait for the response. Timeout after
async with queue.iterator(timeout=CONNECTION_TIMEOUT) as queue_iter:
@@ -369,6 +390,9 @@ async def check_for_messages(guid, num_queries, num_previously_received=0):
except Exception as e:
logger.error(f"{guid}: Exception {e}. Returning {num_responses} results we have so far.")
return responses, True
finally:
if connection:
await connection.close()

return responses, complete

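
Each of the new helpers above (create_queue, check_for_messages, delete_queue) opens its own connection and closes it before returning. The orchestration that ties them together lives in collect_callback_responses, which this diff only touches within the hunk above; the sketch below is a hypothetical, simplified caller (the name gather_callback_responses and the bare while-loop are assumptions) meant only to show the per-call connection lifecycle.

from src.service_aggregator import check_for_messages, create_queue, delete_queue


async def gather_callback_responses(guid: str, num_queries: int) -> list:
    """Hypothetical simplified caller; the real one is collect_callback_responses."""
    await create_queue(guid)  # opens and closes its own connection
    responses: list = []
    complete = False
    try:
        while not complete:
            # each poll opens a fresh connection, drains whatever is queued,
            # and closes the connection before returning
            batch, complete = await check_for_messages(
                guid, num_queries, num_previously_received=len(responses)
            )
            responses.extend(batch)
    finally:
        # best-effort cleanup; delete_queue logs and swallows its own errors
        await delete_queue(guid)
    return responses
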
31 changes: 0 additions & 31 deletions src/util.py
@@ -1,10 +1,5 @@
"""Common Aragorn Utilities."""
import aio_pika
from aio_pika.abc import AbstractRobustConnection
from aio_pika.pool import Pool
import asyncio
import datetime
import os


def create_log_entry(msg: str, err_level, timestamp = datetime.datetime.now().isoformat(), code=None) -> dict:
@@ -13,29 +8,3 @@ def create_log_entry(msg: str, err_level, timestamp = datetime.datetime.now().is

# return to the caller
return ret_val


def get_channel_pool():
# get the queue connection params
q_username = os.environ.get("QUEUE_USER", "guest")
q_password = os.environ.get("QUEUE_PW", "guest")
q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

async def get_connection() -> AbstractRobustConnection:
return await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")


connection_pool: Pool = Pool(get_connection, max_size=4, loop=loop)


async def get_channel() -> aio_pika.Channel:
async with connection_pool.acquire() as connection:
return await connection.channel()


channel_pool: Pool = Pool(get_channel, max_size=10, loop=loop)

return channel_pool
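
The removed get_channel_pool built a connection pool and a layered channel pool on a brand-new event loop created at import time. For contrast, the self-contained sketch below re-creates both patterns with aio_pika's public pool API alongside the single-connection style this commit switches to; the demo queue name is arbitrary, and the pools here are built inside a running coroutine rather than at import time, which differs slightly from the deleted code.

import asyncio
import os

import aio_pika
from aio_pika.abc import AbstractRobustConnection
from aio_pika.pool import Pool

q_username = os.environ.get("QUEUE_USER", "guest")
q_password = os.environ.get("QUEUE_PW", "guest")
q_host = os.environ.get("QUEUE_HOST", "127.0.0.1")


async def pooled_channel_demo() -> None:
    """Old pattern (removed here): a connection pool feeding a channel pool."""

    async def get_connection() -> AbstractRobustConnection:
        return await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")

    connection_pool: Pool = Pool(get_connection, max_size=4)

    async def get_channel() -> aio_pika.Channel:
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool: Pool = Pool(get_channel, max_size=10)

    async with channel_pool.acquire() as channel:
        await channel.declare_queue("demo-guid")

    await channel_pool.close()
    await connection_pool.close()


async def single_connection_demo() -> None:
    """New pattern (this commit): one connection per operation, closed on exit."""
    connection = await aio_pika.connect_robust(f"amqp://{q_username}:{q_password}@{q_host}/")
    async with connection:
        channel = await connection.channel()
        await channel.declare_queue("demo-guid")


if __name__ == "__main__":
    asyncio.run(single_connection_demo())

The trade-off is straightforward: the pooled version amortizes connection setup across calls but ties the pools to whichever event loop built them, while the per-operation version pays the setup cost on every call in exchange for having no shared broker state between requests.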
