diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..dc71603 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.github/ + +.coverage/ +.coverage +coverage + +venv/ +.env diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e73ede4 --- /dev/null +++ b/.env.example @@ -0,0 +1,25 @@ +CHUNK_SIZE= +COHERE_API_KEY= +D_RETRIEVER_SEARCH= +EMBEDDING_DIM= +K1_RETRIEVER_SEARCH= +K2_RETRIEVER_SEARCH= +MONGODB_HOST= +MONGODB_PASS= +MONGODB_PORT= +MONGODB_USER= +NEO4J_DB= +NEO4J_HOST= +NEO4J_PASSWORD= +NEO4J_PORT= +NEO4J_PROTOCOL= +NEO4J_USER= +OPENAI_API_KEY= +POSTGRES_HOST= +POSTGRES_PASS= +POSTGRES_PORT= +POSTGRES_USER= +RABBIT_HOST= +RABBIT_PASSWORD= +RABBIT_PORT= +RABBIT_USER= diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml new file mode 100644 index 0000000..a1be27b --- /dev/null +++ b/.github/workflows/production.yml @@ -0,0 +1,12 @@ +name: Production CI/CD Pipeline + +on: + push: + branches: + - main + +jobs: + ci: + uses: TogetherCrew/operations/.github/workflows/ci.yml@main + secrets: + CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }} \ No newline at end of file diff --git a/.github/workflows/start.staging.yml b/.github/workflows/start.staging.yml new file mode 100644 index 0000000..842e3bd --- /dev/null +++ b/.github/workflows/start.staging.yml @@ -0,0 +1,9 @@ +name: Staging CI/CD Pipeline + +on: pull_request + +jobs: + ci: + uses: TogetherCrew/operations/.github/workflows/ci.yml@main + secrets: + CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 68bc17f..3a5bd6b 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ + +hivemind-bot-env/* +main.ipynb \ No newline at end of file diff --git a/Dockerfile new file mode 100644 index 0000000..e2734a2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +# It's recommended that we use `bullseye` for Python (alpine isn't suitable as it conflicts with numpy) +FROM python:3.11-bullseye AS base +WORKDIR /project +COPY . . +RUN pip3 install -r requirements.txt + +FROM base AS test +RUN chmod +x docker-entrypoint.sh +CMD ["./docker-entrypoint.sh"] + +FROM base AS prod +CMD ["celery", "-A", "celery_app.server", "worker", "-l", "INFO"] diff --git a/bot/__init__.py b/bot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/retrievers/__init__.py b/bot/retrievers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py new file mode 100644 index 0000000..1e04cea --- /dev/null +++ b/bot/retrievers/forum_summary_retriever.py @@ -0,0 +1,73 @@ +from bot.retrievers.summary_retriever_base import BaseSummarySearch +from llama_index.embeddings import BaseEmbedding +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding + + +class ForumBasedSummaryRetriever(BaseSummarySearch): + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding | CohereEmbedding = CohereEmbedding(), + ) -> None: + """ + the class for forum based data like discord and discourse + by default CohereEmbedding will be used.
+ """ + super().__init__(table_name, dbname, embedding_model=embedding_model) + + def retreive_metadata( + self, + query: str, + metadata_group1_key: str, + metadata_group2_key: str, + metadata_date_key: str, + similarity_top_k: int = 20, + ) -> tuple[set[str], set[str], set[str]]: + """ + retrieve the metadata information of the similar nodes with the query + + Parameters + ----------- + query : str + the user query to process + metadata_group1_key : str + the conversations grouping type 1 + in discord can be `channel`, and in discourse can be `category` + metadata_group2_key : str + the conversations grouping type 2 + in discord can be `thread`, and in discourse can be `topic` + metadata_date_key : str + the daily metadata saved key + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + + + Returns + --------- + group1_data : set[str] + the similar summary nodes having the group1_data. + can be an empty set meaning no similar thread + conversations for it was available. + group2_data : set[str] + the similar summary nodes having the group2_data. + can be an empty set meaning no similar channel + conversations for it was available. 
+ dates : set[str] + the similar daily conversations to the given query + """ + nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) + + group1_data: set[str] = set() + dates: set[str] = set() + group2_data: set[str] = set() + + for node in nodes: + if node.metadata[metadata_group1_key]: + group1_data.add(node.metadata[metadata_group1_key]) + if node.metadata[metadata_group2_key]: + group2_data.add(node.metadata[metadata_group2_key]) + dates.add(node.metadata[metadata_date_key]) + + return group1_data, group2_data, dates diff --git a/bot/retrievers/process_dates.py b/bot/retrievers/process_dates.py new file mode 100644 index 0000000..dba3217 --- /dev/null +++ b/bot/retrievers/process_dates.py @@ -0,0 +1,39 @@ +import logging +from datetime import timedelta + +from dateutil import parser + + +def process_dates(dates: list[str], d: int) -> list[str]: + """ + process the dates to be from `date - d` to `date + d` + + Parameters + ------------ + dates : list[str] + the list of dates given + d : int + to update the `dates` list to have `-d` and `+d` days + + + Returns + ---------- + dates_modified : list[str] + days added to it + """ + dates_modified: list[str] = [] + if dates != []: + lowest_date = min(parser.parse(date) for date in dates) + greatest_date = max(parser.parse(date) for date in dates) + + delta_days = timedelta(days=d) + + # the date condition + dt = lowest_date - delta_days + while dt <= greatest_date + delta_days: + dates_modified.append(dt.strftime("%Y-%m-%d")) + dt += timedelta(days=1) + else: + logging.warning("No dates given!") + + return dates_modified diff --git a/bot/retrievers/summary_retriever_base.py b/bot/retrievers/summary_retriever_base.py new file mode 100644 index 0000000..0095a7f --- /dev/null +++ b/bot/retrievers/summary_retriever_base.py @@ -0,0 +1,75 @@ +from llama_index import VectorStoreIndex +from llama_index.embeddings import BaseEmbedding +from llama_index.indices.query.schema import QueryBundle +from 
llama_index.schema import NodeWithScore +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from tc_hivemind_backend.pg_vector_access import PGVectorAccess + + +class BaseSummarySearch: + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding = CohereEmbedding(), + ) -> None: + """ + initialize the base summary search class + + In this class we're doing a similarity search + for available saved nodes under postgresql + + Parameters + ------------- + table_name : str + the table that summary data is saved + *Note:* Don't include the `data_` prefix of the table, + cause lamma_index would original include that. + dbname : str + the database name to access + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + embedding_model : llama_index.embeddings.BaseEmbedding + the embedding model to use for doing embedding on the query string + default would be CohereEmbedding that we've written + """ + self.index = self._setup_index(table_name, dbname, embedding_model) + self.embedding_model = embedding_model + + def get_similar_nodes( + self, query: str, similarity_top_k: int = 20 + ) -> list[NodeWithScore]: + """ + get k similar nodes to the query. + Note: this funciton wold get the embedding + for the query to do the similarity search. + + Parameters + ------------ + query : str + the user query to process + similarity_top_k : int + the top k nodes to get as the retriever. 
+ default is set as 20 + """ + retriever = self.index.as_retriever(similarity_top_k=similarity_top_k) + + query_embedding = self.embedding_model.get_text_embedding(text=query) + + query_bundle = QueryBundle(query_str=query, embedding=query_embedding) + nodes = retriever._retrieve(query_bundle) + + return nodes + + def _setup_index( + self, table_name: str, dbname: str, embedding_model: BaseEmbedding + ) -> VectorStoreIndex: + """ + setup the llama_index VectorStoreIndex + """ + pg_vector_access = PGVectorAccess( + table_name=table_name, dbname=dbname, embed_model=embedding_model + ) + index = pg_vector_access.load_index() + return index diff --git a/bot/retrievers/utils/__init__.py b/bot/retrievers/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/retrievers/utils/load_hyperparams.py b/bot/retrievers/utils/load_hyperparams.py new file mode 100644 index 0000000..98db6ce --- /dev/null +++ b/bot/retrievers/utils/load_hyperparams.py @@ -0,0 +1,34 @@ +import os + +from dotenv import load_dotenv + + +def load_hyperparams() -> tuple[int, int, int]: + """ + load the k1, k2, and d hyperparams that are used for retrievers + + Returns + --------- + k1 : int + the value for the first summary search + to get the `k1` count similar nodes + k2 : int + the value for the secondary raw search + to get the `k2` count simliar nodes + d : int + the before and after day interval + """ + load_dotenv() + + k1 = os.getenv("K1_RETRIEVER_SEARCH") + k2 = os.getenv("K2_RETRIEVER_SEARCH") + d = os.getenv("D_RETRIEVER_SEARCH") + + if k1 is None: + raise ValueError("No `K1_RETRIEVER_SEARCH` available in .env file!") + if k2 is None: + raise ValueError("No `K2_RETRIEVER_SEARCH` available in .env file!") + if d is None: + raise ValueError("No `D_RETRIEVER_SEARCH` available in .env file!") + + return int(k1), int(k2), int(d) diff --git a/celery_app/__init__.py b/celery_app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/celery_app/job_send.py 
b/celery_app/job_send.py new file mode 100644 index 0000000..c1f59ae --- /dev/null +++ b/celery_app/job_send.py @@ -0,0 +1,30 @@ +from tc_messageBroker import RabbitMQ +from tc_messageBroker.rabbit_mq.event import Event +from tc_messageBroker.rabbit_mq.queue import Queue + + +def job_send(broker_url, port, username, password, res): + rabbit_mq = RabbitMQ( + broker_url=broker_url, port=port, username=username, password=password + ) + + content = { + "uuid": "d99a1490-fba6-11ed-b9a9-0d29e7612dp8", + "data": f"some results {res}", + } + + rabbit_mq.connect(Queue.DISCORD_ANALYZER) + rabbit_mq.publish( + queue_name=Queue.DISCORD_ANALYZER, + event=Event.DISCORD_BOT.FETCH, + content=content, + ) + + +if __name__ == "__main__": + # TODO: read from .env + broker_url = "localhost" + port = 5672 + username = "root" + password = "pass" + job_send(broker_url, port, username, password, "CALLED FROM __main__") diff --git a/celery_app/server.py b/celery_app/server.py new file mode 100644 index 0000000..499aff0 --- /dev/null +++ b/celery_app/server.py @@ -0,0 +1,11 @@ +from celery import Celery +from utils.credentials import load_rabbitmq_credentials + +rabbit_creds = load_rabbitmq_credentials() +user = rabbit_creds["user"] +password = rabbit_creds["password"] +host = rabbit_creds["host"] +port = rabbit_creds["port"] + +app = Celery("celery_app/tasks", broker=f"pyamqp://{user}:{password}@{host}:{port}//") +app.autodiscover_tasks(["celery_app"]) diff --git a/celery_app/tasks.py b/celery_app/tasks.py new file mode 100644 index 0000000..5dd154e --- /dev/null +++ b/celery_app/tasks.py @@ -0,0 +1,29 @@ +from celery_app.job_send import job_send +from celery_app.server import app +from utils.credentials import load_rabbitmq_credentials + +# TODO: Write tasks that match our requirements + + +@app.task +def add(x, y): + rabbit_creds = load_rabbitmq_credentials() + username = rabbit_creds["user"] + password = rabbit_creds["password"] + broker_url = rabbit_creds["host"] + port = 
rabbit_creds["port"] + + res = x + y + job_send(broker_url, port, username, password, res) + + return res + + +@app.task +def mul(x, y): + return x * y + + +@app.task +def xsum(numbers): + return sum(numbers) diff --git a/discord_query.py b/discord_query.py new file mode 100644 index 0000000..24e762f --- /dev/null +++ b/discord_query.py @@ -0,0 +1,36 @@ +from llama_index import QueryBundle +from llama_index.schema import NodeWithScore +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter + + +def query_discord( + community_id: str, + query: str, +) -> tuple[str, list[NodeWithScore]]: + """ + query the llm using the query engine + + Parameters + ------------ + query_engine : BaseQueryEngine + the prepared query engine + query : str + the string question + + Returns + ---------- + response : str + the LLM response + source_nodes : list[llama_index.schema.NodeWithScore] + the source nodes that helped in answering the question + """ + query_engine = prepare_discord_engine_auto_filter( + community_id=community_id, + query=query, + ) + query_bundle = QueryBundle( + query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) + ) + response = query_engine.query(query_bundle) + return response.response, response.source_nodes diff --git a/docker-compose.example.yml b/docker-compose.example.yml new file mode 100644 index 0000000..0eccf2e --- /dev/null +++ b/docker-compose.example.yml @@ -0,0 +1,14 @@ +version: "3.9" + +services: + server: + build: + context: . + target: prod + dockerfile: Dockerfile + worker: + build: + context: . + target: prod + dockerfile: Dockerfile + command: python3 worker.py diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 0000000..6375962 --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,88 @@ +version: "3.9" + +services: + app: + build: + context: . 
+ target: test + dockerfile: Dockerfile + environment: + - PORT=3000 + - MONGODB_HOST=mongo + - MONGODB_PORT=27017 + - MONGODB_USER=root + - MONGODB_PASS=pass + - NEO4J_PROTOCOL=bolt + - NEO4J_HOST=neo4j + - NEO4J_PORT=7687 + - NEO4J_USER=neo4j + - NEO4J_PASSWORD=password + - NEO4J_DB=neo4j + - POSTGRES_HOST=postgres + - POSTGRES_USER=root + - POSTGRES_PASS=pass + - POSTGRES_PORT=5432 + - RABBIT_HOST=rabbitmq + - RABBIT_PORT=5672 + - RABBIT_USER=root + - RABBIT_PASSWORD=pass + - CHUNK_SIZE=512 + - EMBEDDING_DIM=1024 + - K1_RETRIEVER_SEARCH=20 + - K2_RETRIEVER_SEARCH=5 + - D_RETRIEVER_SEARCH=7 + - COHERE_API_KEY=some_credentials + - OPENAI_API_KEY=some_credentials2 + volumes: + - ./coverage:/project/coverage + depends_on: + neo4j: + condition: service_healthy + mongo: + condition: service_healthy + postgres: + condition: service_healthy + neo4j: + image: "neo4j:5.9.0" + environment: + - NEO4J_AUTH=neo4j/password + - NEO4J_PLUGINS=["apoc", "graph-data-science"] + - NEO4J_dbms_security_procedures_unrestricted=apoc.*,gds.* + healthcheck: + test: ["CMD" ,"wget", "http://localhost:7474"] + interval: 1m30s + timeout: 10s + retries: 2 + start_period: 40s + mongo: + image: "mongo:6.0.8" + environment: + - MONGO_INITDB_ROOT_USERNAME=root + - MONGO_INITDB_ROOT_PASSWORD=pass + healthcheck: + test: echo 'db.stats().ok' | mongosh localhost:27017/test --quiet + interval: 60s + timeout: 10s + retries: 2 + start_period: 40s + postgres: + image: "ankane/pgvector" + environment: + - POSTGRES_USER=root + - POSTGRES_PASSWORD=pass + healthcheck: + test: ["CMD-SHELL", "pg_isready"] + interval: 10s + timeout: 5s + retries: 5 + rabbitmq: + image: "rabbitmq:3-management-alpine" + environment: + - RABBITMQ_DEFAULT_USER=root + - RABBITMQ_DEFAULT_PASS=pass + healthcheck: + test: rabbitmq-diagnostics -q ping + interval: 30s + timeout: 30s + retries: 2 + start_period: 40s diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100644 index 0000000..5127573 --- /dev/null +++ 
b/docker-entrypoint.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +python3 -m coverage run --omit=tests/* -m pytest . +python3 -m coverage lcov -o coverage/lcov.info \ No newline at end of file diff --git a/requirements.txt new file mode 100644 index 0000000..0a35833 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +numpy +llama-index>=0.9.26, <1.0.0 +pymongo +# python-dotenv is pinned below (python-dotenv==1.0.0) +pgvector +asyncpg +psycopg2-binary +sqlalchemy[asyncio] +async-sqlalchemy +python-pptx +tc-neo4j-lib +google-api-python-client +unstructured +cohere +neo4j>=5.14.1, <6.0.0 +coverage>=7.3.3, <8.0.0 +pytest>=7.4.3, <8.0.0 +python-dotenv==1.0.0 +tc-hivemind-backend==1.1.0 +celery>=5.3.6, <6.0.0 +guidance diff --git a/subquery.py b/subquery.py new file mode 100644 index 0000000..58dfd21 --- /dev/null +++ b/subquery.py @@ -0,0 +1,111 @@ +from guidance.models import OpenAIChat +from llama_index import QueryBundle, ServiceContext +from llama_index.core import BaseQueryEngine +from llama_index.query_engine import SubQuestionQueryEngine +from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator +from llama_index.schema import NodeWithScore +from llama_index.tools import QueryEngineTool, ToolMetadata +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from utils.query_engine import prepare_discord_engine_auto_filter + + +def query_multiple_source( + query: str, + community_id: str, + discord: bool, + discourse: bool, + gdrive: bool, + notion: bool, + telegram: bool, + github: bool, +) -> tuple[str, list[NodeWithScore]]: + """ + query multiple platforms and get an answer from the multiple + + Parameters + ------------ + query : str + the user question + community_id : str + the community id to get their data + discord : bool + if `True` then add the engine to the subquery_generator + discourse : bool + if `True` then add the engine to the subquery_generator + gdrive : bool + if `True` then add the engine to the subquery_generator + notion : bool +
if `True` then add the engine to the subquery_generator + telegram : bool + if `True` then add the engine to the subquery_generator + github : bool + if `True` then add the engine to the subquery_generator + + + Returns + -------- + response : str, + the response to the user query from the LLM + using the engines of the given platforms (pltform equal to True) + source_nodes : list[NodeWithScore] + the list of nodes that were source of answering + """ + query_engine_tools: list[QueryEngineTool] = [] + tools: list[ToolMetadata] = [] + + discord_query_engine: BaseQueryEngine + # discourse_query_engine: BaseQueryEngine + # gdrive_query_engine: BaseQueryEngine + # notion_query_engine: BaseQueryEngine + # telegram_query_engine: BaseQueryEngine + # github_query_engine: BaseQueryEngine + + # query engine perparation + # tools_metadata and query_engine_tools + if discord: + discord_query_engine = prepare_discord_engine_auto_filter( + community_id, + query, + similarity_top_k=None, + d=None, + ) + tool_metadata = ToolMetadata( + name="Discord", + description="Contains messages and summaries of conversations from the Discord platform of the community", + ) + + tools.append(tool_metadata) + query_engine_tools.append( + QueryEngineTool( + query_engine=discord_query_engine, + metadata=tool_metadata, + ) + ) + + if discourse: + raise NotImplementedError + if gdrive: + raise NotImplementedError + if notion: + raise NotImplementedError + if telegram: + raise NotImplementedError + if github: + raise NotImplementedError + + question_gen = GuidanceQuestionGenerator.from_defaults( + guidance_llm=OpenAIChat("gpt-3.5-turbo"), + verbose=False, + ) + embed_model = CohereEmbedding() + service_context = ServiceContext.from_defaults(embed_model=embed_model) + s_engine = SubQuestionQueryEngine.from_defaults( + question_gen=question_gen, + query_engine_tools=query_engine_tools, + use_async=False, + service_context=service_context, + ) + query_embedding = 
embed_model.get_text_embedding(text=query) + response = s_engine.query(QueryBundle(query_str=query, embedding=query_embedding)) + + return response.response, response.source_nodes diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py new file mode 100644 index 0000000..d5fafa3 --- /dev/null +++ b/tests/unit/test_discord_summary_retriever.py @@ -0,0 +1,82 @@ +from datetime import timedelta +from functools import partial +from unittest import TestCase +from unittest.mock import MagicMock + +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from dateutil import parser +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex + + +class TestDiscordSummaryRetriever(TestCase): + def test_initialize_class(self): + ForumBasedSummaryRetriever._setup_index = MagicMock() + documents: list[Document] = [] + all_dates: list[str] = [] + + for i in range(30): + date = parser.parse("2023-08-01") + timedelta(days=i) + doc_date = date.strftime("%Y-%m-%d") + doc = Document( + text="SAMPLESAMPLESAMPLE", + metadata={ + "thread": f"thread{i % 5}", + "channel": f"channel{i % 3}", + "date": doc_date, + }, + ) + all_dates.append(doc_date) + documents.append(doc) + + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + ForumBasedSummaryRetriever._setup_index.return_value = ( + VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + ) + + base_summary_search = ForumBasedSummaryRetriever( + table_name="sample", + dbname="sample", + 
embedding_model=mock_embedding_model(), + ) + channels, threads, dates = base_summary_search.retreive_metadata( + query="what is samplesample?", + similarity_top_k=5, + metadata_group1_key="channel", + metadata_group2_key="thread", + metadata_date_key="date", + ) + self.assertIsInstance(threads, set) + self.assertIsInstance(channels, set) + self.assertIsInstance(dates, set) + + self.assertTrue( + threads.issubset( + set( + [ + "thread0", + "thread1", + "thread2", + "thread3", + "thread4", + ] + ) + ) + ) + self.assertTrue( + channels.issubset( + set( + [ + "channel0", + "channel1", + "channel2", + ] + ) + ) + ) + self.assertTrue(dates.issubset(all_dates)) diff --git a/tests/unit/test_load_credentials.py b/tests/unit/test_load_credentials.py new file mode 100644 index 0000000..5bfa795 --- /dev/null +++ b/tests/unit/test_load_credentials.py @@ -0,0 +1,41 @@ +import unittest + +from utils.credentials import load_postgres_credentials, load_rabbitmq_credentials + + +class TestCredentialLoadings(unittest.TestCase): + def test_postgresql_envs_check_type(self): + postgres_creds = load_postgres_credentials() + + self.assertIsInstance(postgres_creds, dict) + + def test_postgresql_envs_values(self): + postgres_creds = load_postgres_credentials() + + self.assertNotEqual(postgres_creds["user"], None) + self.assertNotEqual(postgres_creds["password"], None) + self.assertNotEqual(postgres_creds["host"], None) + self.assertNotEqual(postgres_creds["port"], None) + + self.assertIsInstance(postgres_creds["user"], str) + self.assertIsInstance(postgres_creds["password"], str) + self.assertIsInstance(postgres_creds["host"], str) + self.assertIsInstance(postgres_creds["port"], str) + + def test_rabbitmq_envs_check_type(self): + rabbitmq_creds = load_rabbitmq_credentials() + + self.assertIsInstance(rabbitmq_creds, dict) + + def test_rabbitmq_envs_values(self): + rabbitmq_creds = load_rabbitmq_credentials() + + self.assertNotEqual(rabbitmq_creds["user"], None) +
self.assertNotEqual(rabbitmq_creds["password"], None) + self.assertNotEqual(rabbitmq_creds["host"], None) + self.assertNotEqual(rabbitmq_creds["port"], None) + + self.assertIsInstance(rabbitmq_creds["user"], str) + self.assertIsInstance(rabbitmq_creds["password"], str) + self.assertIsInstance(rabbitmq_creds["host"], str) + self.assertIsInstance(rabbitmq_creds["port"], str) diff --git a/tests/unit/test_load_retriever_hyperparameters.py b/tests/unit/test_load_retriever_hyperparameters.py new file mode 100644 index 0000000..1f9c2fa --- /dev/null +++ b/tests/unit/test_load_retriever_hyperparameters.py @@ -0,0 +1,73 @@ +import unittest +from unittest.mock import patch + +from bot.retrievers.utils.load_hyperparams import load_hyperparams + + +class TestLoadHyperparams(unittest.TestCase): + @patch("os.getenv") + def test_valid_hyperparams(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + result = load_hyperparams() + self.assertEqual(result, (10, 20, 30)) + + @patch("os.getenv") + def test_missing_k1(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_missing_k2(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_missing_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_k1(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "invalid", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + 
}.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_k2(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "invalid", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "invalid", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py new file mode 100644 index 0000000..72ad1a6 --- /dev/null +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -0,0 +1,50 @@ +import os +import unittest + +from llama_index.core import BaseQueryEngine +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters +from utils.query_engine.discord_query_engine import prepare_discord_engine + + +class TestPrepareDiscordEngine(unittest.TestCase): + def setUp(self): + # Set up environment variables for testing + os.environ["CHUNK_SIZE"] = "128" + os.environ["EMBEDDING_DIM"] = "256" + os.environ["K1_RETRIEVER_SEARCH"] = "20" + os.environ["K2_RETRIEVER_SEARCH"] = "5" + os.environ["D_RETRIEVER_SEARCH"] = "3" + + def test_prepare_discord_engine(self): + community_id = "123456" + thread_names = ["thread1", "thread2"] + channel_names = ["channel1", "channel2"] + days = ["2022-01-01", "2022-01-02"] + + # Call the function + query_engine = prepare_discord_engine( + community_id, + thread_names, + channel_names, + days, + testing=True, + ) + + # Assertions + self.assertIsInstance(query_engine, BaseQueryEngine) + + expected_filter = MetadataFilters( + filters=[ + ExactMatchFilter(key="thread", value="thread1"), + ExactMatchFilter(key="thread", value="thread2"), + 
ExactMatchFilter(key="channel", value="channel1"), + ExactMatchFilter(key="channel", value="channel2"), + ExactMatchFilter(key="date", value="2022-01-01"), + ExactMatchFilter(key="date", value="2022-01-02"), + ], + condition=FilterCondition.OR, + ) + + self.assertEqual(query_engine.retriever._filters, expected_filter) + # this is the secondary search, so K2 should be for this + self.assertEqual(query_engine.retriever._similarity_top_k, 5) diff --git a/tests/unit/test_process_dates_forum_retriever_search.py b/tests/unit/test_process_dates_forum_retriever_search.py new file mode 100644 index 0000000..9b44f3c --- /dev/null +++ b/tests/unit/test_process_dates_forum_retriever_search.py @@ -0,0 +1,42 @@ +import unittest + +from bot.retrievers.process_dates import process_dates + + +class TestProcessDates(unittest.TestCase): + def test_process_dates_with_valid_input(self): + # Test with a valid input + input_dates = ["2023-01-01", "2023-01-03", "2023-01-05"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_empty_input(self): + # Test with an empty input + input_dates = [] + d = 2 + expected_output = [] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_single_date(self): + # Test with a single date in the input + input_dates = ["2023-01-01"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py new file mode 100644 index 0000000..14180ac --- /dev/null +++ b/tests/unit/test_summary_retriever_base.py @@ -0,0 +1,30 @@ +from functools import partial +from unittest import TestCase 
+from unittest.mock import MagicMock + +from bot.retrievers.summary_retriever_base import BaseSummarySearch +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex +from llama_index.schema import NodeWithScore + + +class TestSummaryRetrieverBase(TestCase): + def test_initialize_class(self): + BaseSummarySearch._setup_index = MagicMock() + doc = Document(text="SAMPLESAMPLESAMPLE") + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + BaseSummarySearch._setup_index.return_value = VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + + base_summary_search = BaseSummarySearch( + table_name="sample", + dbname="sample", + embedding_model=mock_embedding_model(), + ) + nodes = base_summary_search.get_similar_nodes(query="what is samplesample?") + self.assertIsInstance(nodes, list) + self.assertIsInstance(nodes[0], NodeWithScore) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/credentials.py b/utils/credentials.py new file mode 100644 index 0000000..234c82c --- /dev/null +++ b/utils/credentials.py @@ -0,0 +1,55 @@ +import os + +from dotenv import load_dotenv + + +def load_postgres_credentials() -> dict[str, str]: + """ + load posgresql db credentials from .env + + Returns: + --------- + postgres_creds : dict[str, Any] + postgresql credentials + a dictionary representive of + `user`: str + `password` : str + `host` : str + `port` : int + """ + load_dotenv() + + postgres_creds = {} + + postgres_creds["user"] = os.getenv("POSTGRES_USER", "") + postgres_creds["password"] = os.getenv("POSTGRES_PASS", "") + postgres_creds["host"] = os.getenv("POSTGRES_HOST", "") + postgres_creds["port"] = os.getenv("POSTGRES_PORT", "") + + return postgres_creds + + +def load_rabbitmq_credentials() -> dict[str, str]: + """ + load rabbitmq credentials from 
def prepare_discord_engine(
    community_id: str,
    thread_names: list[str],
    channel_names: list[str],
    days: list[str],
    similarity_top_k: int | None = None,
    **kwargs,
) -> BaseQueryEngine:
    """
    Query the discord database using the given filters
    and give an answer to the given query using the LLM.

    Parameters
    ------------
    community_id : str
        the discord community data to query
    thread_names : list[str]
        the given threads to search for
    channel_names : list[str]
        the given channels to search for
    days : list[str]
        the given days to search for
    similarity_top_k : int | None
        the k similar results to use when querying the data
        if `None` will load from `.env` file
    **kwargs :
        testing : bool
            whether to set up the PGVectorAccess in testing mode

    Returns
    ---------
    query_engine : BaseQueryEngine
        the created query engine with the filters
    """
    table_name = "discord"
    dbname = f"community_{community_id}"

    testing = kwargs.get("testing", False)

    pg_vector = PGVectorAccess(
        table_name=table_name,
        dbname=dbname,
        testing=testing,
        embed_model=CohereEmbedding(),
    )
    index = pg_vector.load_index()
    if similarity_top_k is None:
        # hyperparams tuple is (k1, k2, d); k2 is the retriever's top-k
        _, similarity_top_k, _ = load_hyperparams()

    # double-up single quotes so the filter values are safely escaped
    channel_filters: list[ExactMatchFilter] = [
        ExactMatchFilter(key="channel", value=channel.replace("'", "''"))
        for channel in channel_names
    ]
    thread_filters: list[ExactMatchFilter] = [
        ExactMatchFilter(key="thread", value=thread.replace("'", "''"))
        for thread in thread_names
    ]
    day_filters: list[ExactMatchFilter] = [
        ExactMatchFilter(key="date", value=day) for day in days
    ]

    # OR semantics: a node matches if ANY thread, channel or day filter matches
    filters = MetadataFilters(
        filters=[*thread_filters, *channel_filters, *day_filters],
        condition=FilterCondition.OR,
    )

    query_engine = index.as_query_engine(
        filters=filters, similarity_top_k=similarity_top_k
    )

    return query_engine
# TODO: Update according to our requirements
def do_something(recieved_data):
    """
    Sample event handler: print the delivered payload and enqueue
    a demo celery `add` task with fixed arguments.

    Parameters
    ------------
    recieved_data
        the payload delivered with the consumed RabbitMQ event
    """
    print("Calculation Results:")
    print(f"recieved_data: {recieved_data}")
    # fire-and-forget: queue `add(20, 14)` on the celery worker
    add.delay(20, 14)
if __name__ == "__main__":
    # broker connection info comes from the environment (.env)
    creds = load_rabbitmq_credentials()
    # argument order expected by job_recieve:
    # broker_url, port, username, password
    job_recieve(
        creds["host"], creds["port"], creds["user"], creds["password"]
    )