From a467be009103917ef15bb59216f6f6e0155be672 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 16:25:50 +0330 Subject: [PATCH 01/25] feat: initializing the bot! --- .dockerignore | 8 + .github/workflows/production.yml | 12 ++ .github/workflows/start.staging.yml | 9 ++ .gitignore | 2 + Dockerfile | 12 ++ celery_app/__init__.py | 0 celery_app/job_send.py | 30 ++++ celery_app/server.py | 5 + celery_app/tasks.py | 27 ++++ discord_query.py | 147 ++++++++++++++++++ docker-compose.example.yml | 14 ++ docker-compose.test.yml | 71 +++++++++ docker-entrypoint.sh | 3 + requirements.txt | 20 +++ retrievers/__init__.py | 0 retrievers/forum_summary_retriever.py | 74 +++++++++ retrievers/process_dates.py | 39 +++++ retrievers/summary_retriever_base.py | 72 +++++++++ retrievers/utils/__init__.py | 0 retrievers/utils/load_hyperparams.py | 34 ++++ tests/__init__.py | 0 tests/integration/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/test_discord_summary_retriever.py | 84 ++++++++++ .../test_load_retriever_hyperparameters.py | 73 +++++++++ ...st_process_dates_forum_retriever_search.py | 42 +++++ tests/unit/test_summary_retriever_base.py | 30 ++++ worker.py | 39 +++++ 28 files changed, 847 insertions(+) create mode 100644 .dockerignore create mode 100644 .github/workflows/production.yml create mode 100644 .github/workflows/start.staging.yml create mode 100644 Dockerfile create mode 100644 celery_app/__init__.py create mode 100644 celery_app/job_send.py create mode 100644 celery_app/server.py create mode 100644 celery_app/tasks.py create mode 100644 discord_query.py create mode 100644 docker-compose.example.yml create mode 100644 docker-compose.test.yml create mode 100644 docker-entrypoint.sh create mode 100644 requirements.txt create mode 100644 retrievers/__init__.py create mode 100644 retrievers/forum_summary_retriever.py create mode 100644 retrievers/process_dates.py create mode 100644 retrievers/summary_retriever_base.py create mode 100644 
retrievers/utils/__init__.py create mode 100644 retrievers/utils/load_hyperparams.py create mode 100644 tests/__init__.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_discord_summary_retriever.py create mode 100644 tests/unit/test_load_retriever_hyperparameters.py create mode 100644 tests/unit/test_process_dates_forum_retriever_search.py create mode 100644 tests/unit/test_summary_retriever_base.py create mode 100644 worker.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..dc71603 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.github/ + +.coverage/ +.coverage +coverage + +venv/ +.env diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml new file mode 100644 index 0000000..a1be27b --- /dev/null +++ b/.github/workflows/production.yml @@ -0,0 +1,12 @@ +name: Production CI/CD Pipeline + +on: + push: + branches: + - main + +jobs: + ci: + uses: TogetherCrew/operations/.github/workflows/ci.yml@main + secrets: + CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }} \ No newline at end of file diff --git a/.github/workflows/start.staging.yml b/.github/workflows/start.staging.yml new file mode 100644 index 0000000..842e3bd --- /dev/null +++ b/.github/workflows/start.staging.yml @@ -0,0 +1,9 @@ +name: Staging CI/CD Pipeline + +on: pull_request + +jobs: + ci: + uses: TogetherCrew/operations/.github/workflows/ci.yml@main + secrets: + CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 68bc17f..1cd6533 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ + +hivemind-bot-env/* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e2734a2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +# It's recommended that we use `bullseye` for Python (alpine isn't suitable as it conflcts with numpy) +FROM python:3.11-bullseye AS base +WORKDIR /project +COPY . . +RUN pip3 install -r requirements.txt + +FROM base AS test +RUN chmod +x docker-entrypoint.sh +CMD ["./docker-entrypoint.sh"] + +FROM base AS prod +CMD ["python3", "celery", "-A", "celery_app.server", "worker", "-l", "INFO"] diff --git a/celery_app/__init__.py b/celery_app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/celery_app/job_send.py b/celery_app/job_send.py new file mode 100644 index 0000000..c1f59ae --- /dev/null +++ b/celery_app/job_send.py @@ -0,0 +1,30 @@ +from tc_messageBroker import RabbitMQ +from tc_messageBroker.rabbit_mq.event import Event +from tc_messageBroker.rabbit_mq.queue import Queue + + +def job_send(broker_url, port, username, password, res): + rabbit_mq = RabbitMQ( + broker_url=broker_url, port=port, username=username, password=password + ) + + content = { + "uuid": "d99a1490-fba6-11ed-b9a9-0d29e7612dp8", + "data": f"some results {res}", + } + + rabbit_mq.connect(Queue.DISCORD_ANALYZER) + rabbit_mq.publish( + queue_name=Queue.DISCORD_ANALYZER, + event=Event.DISCORD_BOT.FETCH, + content=content, + ) + + +if __name__ == "__main__": + # TODO: read from .env + broker_url = "localhost" + port = 5672 + username = "root" + password = "pass" + job_send(broker_url, port, username, password, "CALLED FROM __main__") diff --git a/celery_app/server.py b/celery_app/server.py new file mode 100644 index 0000000..c1c44d4 --- /dev/null +++ b/celery_app/server.py @@ -0,0 +1,5 @@ +from celery import Celery + +# TODO: read from .env +app = Celery("celery_app/tasks", broker="pyamqp://root:pass@localhost//") +app.autodiscover_tasks(["celery_app"]) diff --git a/celery_app/tasks.py 
b/celery_app/tasks.py new file mode 100644 index 0000000..73652a7 --- /dev/null +++ b/celery_app/tasks.py @@ -0,0 +1,27 @@ +from celery_app.server import app +from celery_app.job_send import job_send + +# TODO: Write tasks that match our requirements + + +@app.task +def add(x, y): + broker_url = "localhost" + port = 5672 + username = "root" + password = "pass" + + res = x + y + job_send(broker_url, port, username, password, res) + + return res + + +@app.task +def mul(x, y): + return x * y + + +@app.task +def xsum(numbers): + return sum(numbers) diff --git a/discord_query.py b/discord_query.py new file mode 100644 index 0000000..6bb6fb8 --- /dev/null +++ b/discord_query.py @@ -0,0 +1,147 @@ +from retrievers.forum_summary_retriever import ( + ForumBasedSummaryRetriever, +) +from retrievers.process_dates import process_dates +from retrievers.utils.load_hyperparams import load_hyperparams +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from llama_index import QueryBundle +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters + + +def query_discord( + community_id: str, + query: str, + thread_names: list[str], + channel_names: list[str], + days: list[str], + similarity_top_k: int | None = None, +) -> str: + """ + query the discord database using filters given + and give an anwer to the given query using the LLM + + Parameters + ------------ + guild_id : str + the discord guild data to query + query : str + the query (question) of the user + thread_names : list[str] + the given threads to search for + channel_names : list[str] + the given channels to search for + days : list[str] + the given days to search for + similarity_top_k : int | None + the k similar results to use when querying the data + if `None` will load from `.env` file + + Returns + --------- + response : str + the LLM response given the query + """ + if similarity_top_k is None: + _, 
similarity_top_k, _ = load_hyperparams() + + table_name = "discord" + dbname = f"community_{community_id}" + + pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) + + index = pg_vector.load_index() + + thread_filters: list[ExactMatchFilter] = [] + channel_filters: list[ExactMatchFilter] = [] + day_filters: list[ExactMatchFilter] = [] + + for channel in channel_names: + channel_updated = channel.replace("'", "''") + channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated)) + + for thread in thread_names: + thread_updated = thread.replace("'", "''") + thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated)) + + for day in days: + day_filters.append(ExactMatchFilter(key="date", value=day)) + + all_filters: list[ExactMatchFilter] = [] + all_filters.extend(thread_filters) + all_filters.extend(channel_filters) + all_filters.extend(day_filters) + + filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) + + query_engine = index.as_query_engine( + filters=filters, similarity_top_k=similarity_top_k + ) + + query_bundle = QueryBundle( + query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) + ) + response = query_engine.query(query_bundle) + + return response.response + + +def query_discord_auto_filter( + community_id: str, + query: str, + similarity_top_k: int | None = None, + d: int | None = None, +) -> str: + """ + get the query results and do the filtering automatically. 
+ By automatically we mean, it would first query the summaries + to get the metadata filters + + Parameters + ----------- + guild_id : str + the discord guild data to query + query : str + the query (question) of the user + similarity_top_k : int | None + the value for the initial summary search + to get the `k2` count simliar nodes + if `None`, then would read from `.env` + d : int + this would make the secondary search (`query_discord`) + to be done on the `metadata.date - d` to `metadata.date + d` + + + Returns + --------- + response : str + the LLM response given the query + """ + table_name = "discord_summary" + dbname = f"community_{community_id}" + + if d is None: + _, _, d = load_hyperparams() + if similarity_top_k is None: + similarity_top_k, _, _ = load_hyperparams() + + discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) + + channels, threads, dates = discord_retriever.retreive_metadata( + query=query, + metadata_group1_key="channel", + metadata_group2_key="thread", + metadata_date_key="date", + similarity_top_k=similarity_top_k, + ) + + dates_modified = process_dates(dates, d) + + response = query_discord( + community_id=community_id, + query=query, + thread_names=threads, + channel_names=channels, + days=dates_modified, + ) + return response diff --git a/docker-compose.example.yml b/docker-compose.example.yml new file mode 100644 index 0000000..0eccf2e --- /dev/null +++ b/docker-compose.example.yml @@ -0,0 +1,14 @@ +version: "3.9" + +services: + server: + build: + context: . + target: prod + dockerfile: Dockerfile + worker: + build: + context: . + target: prod + dockerfile: Dockerfile + command: python3 worker.py diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 0000000..97fbcea --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,71 @@ +version: "3.9" + +services: + app: + build: + context: . 
+ target: test + dockerfile: Dockerfile + environment: + - PORT=3000 + - MONGODB_HOST=mongo + - MONGODB_PORT=27017 + - MONGODB_USER=root + - MONGODB_PASS=pass + - NEO4J_PROTOCOL=bolt + - NEO4J_HOST=neo4j + - NEO4J_PORT=7687 + - NEO4J_USER=neo4j + - NEO4J_PASSWORD=password + - NEO4J_DB=neo4j + - POSTGRES_HOST=postgres + - POSTGRES_USER=root + - POSTGRES_PASS=pass + - POSTGRES_PORT=5432 + - CHUNK_SIZE=512 + - EMBEDDING_DIM=1024 + - K1_RETRIEVER_SEARCH=20 + - K2_RETRIEVER_SEARCH=5 + - D_RETRIEVER_SEARCH=7 + volumes: + - ./coverage:/project/coverage + depends_on: + neo4j: + condition: service_healthy + mongo: + condition: service_healthy + postgres: + condition: service_healthy + neo4j: + image: "neo4j:5.9.0" + environment: + - NEO4J_AUTH=neo4j/password + - NEO4J_PLUGINS=["apoc", "graph-data-science"] + - NEO4J_dbms_security_procedures_unrestricted=apoc.*,gds.* + healthcheck: + test: ["CMD" ,"wget", "http://localhost:7474"] + interval: 1m30s + timeout: 10s + retries: 2 + start_period: 40s + mongo: + image: "mongo:6.0.8" + environment: + - MONGO_INITDB_ROOT_USERNAME=root + - MONGO_INITDB_ROOT_PASSWORD=pass + healthcheck: + test: echo 'db.stats().ok' | mongosh localhost:27017/test --quiet + interval: 60s + timeout: 10s + retries: 2 + start_period: 40s + postgres: + image: "ankane/pgvector" + environment: + - POSTGRES_USER=root + - POSTGRES_PASSWORD=pass + healthcheck: + test: ["CMD-SHELL", "pg_isready"] + interval: 10s + timeout: 5s + retries: 5 diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100644 index 0000000..5127573 --- /dev/null +++ b/docker-entrypoint.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +python3 -m coverage run --omit=tests/* -m pytest . 
+python3 -m coverage lcov -o coverage/lcov.info \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c7886f5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +numpy +llama-index>=0.9.21, <1.0.0 +pymongo +python-dotenv +pgvector +asyncpg +psycopg2-binary +sqlalchemy[asyncio] +async-sqlalchemy +python-pptx +tc-neo4j-lib +google-api-python-client +unstructured +cohere +neo4j>=5.14.1, <6.0.0 +coverage>=7.3.3, <8.0.0 +pytest>=7.4.3, <8.0.0 +python-dotenv==1.0.0 +tc_hivemind_backend==1.0.0 +celery>=5.3.6, <6.0.0 diff --git a/retrievers/__init__.py b/retrievers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/retrievers/forum_summary_retriever.py b/retrievers/forum_summary_retriever.py new file mode 100644 index 0000000..58b4d3e --- /dev/null +++ b/retrievers/forum_summary_retriever.py @@ -0,0 +1,74 @@ +from retrievers.summary_retriever_base import BaseSummarySearch +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding + +from llama_index.embeddings import BaseEmbedding + + +class ForumBasedSummaryRetriever(BaseSummarySearch): + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding | CohereEmbedding = CohereEmbedding(), + ) -> None: + """ + the class for forum based data like discord and discourse + by default CohereEmbedding will be used. 
+ """ + super().__init__(table_name, dbname, embedding_model=embedding_model) + + def retreive_metadata( + self, + query: str, + metadata_group1_key: str, + metadata_group2_key: str, + metadata_date_key: str, + similarity_top_k: int = 20, + ) -> tuple[set[str], set[str], set[str]]: + """ + retrieve the metadata information of the similar nodes with the query + + Parameters + ----------- + query : str + the user query to process + metadata_group1_key : str + the conversations grouping type 1 + in discord can be `channel`, and in discourse can be `category` + metadata_group2_key : str + the conversations grouping type 2 + in discord can be `thread`, and in discourse can be `topic` + metadata_date_key : str + the daily metadata saved key + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + + + Returns + --------- + group1_data : set[str] + the similar summary nodes having the group1_data. + can be an empty set meaning no similar thread + conversations for it was available. + group2_data : set[str] + the similar summary nodes having the group2_data. + can be an empty set meaning no similar channel + conversations for it was available. 
+ dates : set[str] + the similar daily conversations to the given query + """ + nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) + + group1_data: set[str] = set() + dates: set[str] = set() + group2_data: set[str] = set() + + for node in nodes: + if node.metadata[metadata_group1_key]: + group1_data.add(node.metadata[metadata_group1_key]) + if node.metadata[metadata_group2_key]: + group2_data.add(node.metadata[metadata_group2_key]) + dates.add(node.metadata[metadata_date_key]) + + return group1_data, group2_data, dates diff --git a/retrievers/process_dates.py b/retrievers/process_dates.py new file mode 100644 index 0000000..dba3217 --- /dev/null +++ b/retrievers/process_dates.py @@ -0,0 +1,39 @@ +import logging +from datetime import timedelta + +from dateutil import parser + + +def process_dates(dates: list[str], d: int) -> list[str]: + """ + process the dates to be from `date - d` to `date + d` + + Parameters + ------------ + dates : list[str] + the list of dates given + d : int + to update the `dates` list to have `-d` and `+d` days + + + Returns + ---------- + dates_modified : list[str] + days added to it + """ + dates_modified: list[str] = [] + if dates != []: + lowest_date = min(parser.parse(date) for date in dates) + greatest_date = max(parser.parse(date) for date in dates) + + delta_days = timedelta(days=d) + + # the date condition + dt = lowest_date - delta_days + while dt <= greatest_date + delta_days: + dates_modified.append(dt.strftime("%Y-%m-%d")) + dt += timedelta(days=1) + else: + logging.warning("No dates given!") + + return dates_modified diff --git a/retrievers/summary_retriever_base.py b/retrievers/summary_retriever_base.py new file mode 100644 index 0000000..8cedca8 --- /dev/null +++ b/retrievers/summary_retriever_base.py @@ -0,0 +1,72 @@ +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding + +from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from llama_index import VectorStoreIndex +from 
llama_index.embeddings import BaseEmbedding +from llama_index.indices.query.schema import QueryBundle +from llama_index.schema import NodeWithScore + + +class BaseSummarySearch: + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding = CohereEmbedding(), + ) -> None: + """ + initialize the base summary search class + + In this class we're doing a similarity search + for available saved nodes under postgresql + + Parameters + ------------- + table_name : str + the table that summary data is saved + *Note:* Don't include the `data_` prefix of the table, + cause lamma_index would original include that. + dbname : str + the database name to access + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + embedding_model : llama_index.embeddings.BaseEmbedding + the embedding model to use for doing embedding on the query string + default would be CohereEmbedding that we've written + """ + self.index = self._setup_index(table_name, dbname) + self.embedding_model = embedding_model + + def get_similar_nodes( + self, query: str, similarity_top_k: int = 20 + ) -> list[NodeWithScore]: + """ + get k similar nodes to the query. + Note: this funciton wold get the embedding + for the query to do the similarity search. + + Parameters + ------------ + query : str + the user query to process + similarity_top_k : int + the top k nodes to get as the retriever. 
+ default is set as 20 + """ + retriever = self.index.as_retriever(similarity_top_k=similarity_top_k) + + query_embedding = self.embedding_model.get_text_embedding(text=query) + + query_bundle = QueryBundle(query_str=query, embedding=query_embedding) + nodes = retriever._retrieve(query_bundle) + + return nodes + + def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex: + """ + setup the llama_index VectorStoreIndex + """ + pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname) + index = pg_vector_access.load_index() + return index diff --git a/retrievers/utils/__init__.py b/retrievers/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/retrievers/utils/load_hyperparams.py b/retrievers/utils/load_hyperparams.py new file mode 100644 index 0000000..98db6ce --- /dev/null +++ b/retrievers/utils/load_hyperparams.py @@ -0,0 +1,34 @@ +import os + +from dotenv import load_dotenv + + +def load_hyperparams() -> tuple[int, int, int]: + """ + load the k1, k2, and d hyperparams that are used for retrievers + + Returns + --------- + k1 : int + the value for the first summary search + to get the `k1` count similar nodes + k2 : int + the value for the secondary raw search + to get the `k2` count simliar nodes + d : int + the before and after day interval + """ + load_dotenv() + + k1 = os.getenv("K1_RETRIEVER_SEARCH") + k2 = os.getenv("K2_RETRIEVER_SEARCH") + d = os.getenv("D_RETRIEVER_SEARCH") + + if k1 is None: + raise ValueError("No `K1_RETRIEVER_SEARCH` available in .env file!") + if k2 is None: + raise ValueError("No `K2_RETRIEVER_SEARCH` available in .env file!") + if d is None: + raise ValueError("No `D_RETRIEVER_SEARCH` available in .env file!") + + return int(k1), int(k2), int(d) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py new file mode 100644 index 0000000..915742c --- /dev/null +++ b/tests/unit/test_discord_summary_retriever.py @@ -0,0 +1,84 @@ +from datetime import timedelta +from functools import partial +from unittest import TestCase +from unittest.mock import MagicMock + +from retrievers.forum_summary_retriever import ( + ForumBasedSummaryRetriever, +) +from dateutil import parser +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex + + +class TestDiscordSummaryRetriever(TestCase): + def test_initialize_class(self): + ForumBasedSummaryRetriever._setup_index = MagicMock() + documents: list[Document] = [] + all_dates: list[str] = [] + + for i in range(30): + date = parser.parse("2023-08-01") + timedelta(days=i) + doc_date = date.strftime("%Y-%m-%d") + doc = Document( + text="SAMPLESAMPLESAMPLE", + metadata={ + "thread": f"thread{i % 5}", + "channel": f"channel{i % 3}", + "date": doc_date, + }, + ) + all_dates.append(doc_date) + documents.append(doc) + + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + ForumBasedSummaryRetriever._setup_index.return_value = ( + VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + ) + + base_summary_search = ForumBasedSummaryRetriever( + table_name="sample", + dbname="sample", + embedding_model=mock_embedding_model(), + ) + channels, threads, dates = base_summary_search.retreive_metadata( + query="what is samplesample?", + similarity_top_k=5, + metadata_group1_key="channel", + metadata_group2_key="thread", + metadata_date_key="date", + ) + self.assertIsInstance(threads, set) + self.assertIsInstance(channels, set) + self.assertIsInstance(dates, set) + + self.assertTrue( + 
threads.issubset( + set( + [ + "thread0", + "thread1", + "thread2", + "thread3", + "thread4", + ] + ) + ) + ) + self.assertTrue( + channels.issubset( + set( + [ + "channel0", + "channel1", + "channel2", + ] + ) + ) + ) + self.assertTrue(dates.issubset(all_dates)) diff --git a/tests/unit/test_load_retriever_hyperparameters.py b/tests/unit/test_load_retriever_hyperparameters.py new file mode 100644 index 0000000..eadcbdc --- /dev/null +++ b/tests/unit/test_load_retriever_hyperparameters.py @@ -0,0 +1,73 @@ +import unittest +from unittest.mock import patch + +from retrievers.utils.load_hyperparams import load_hyperparams + + +class TestLoadHyperparams(unittest.TestCase): + @patch("os.getenv") + def test_valid_hyperparams(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + result = load_hyperparams() + self.assertEqual(result, (10, 20, 30)) + + @patch("os.getenv") + def test_missing_k1(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_missing_k2(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_missing_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_k1(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "invalid", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_k2(self, 
mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "invalid", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "invalid", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() diff --git a/tests/unit/test_process_dates_forum_retriever_search.py b/tests/unit/test_process_dates_forum_retriever_search.py new file mode 100644 index 0000000..f580c82 --- /dev/null +++ b/tests/unit/test_process_dates_forum_retriever_search.py @@ -0,0 +1,42 @@ +import unittest + +from retrievers.process_dates import process_dates + + +class TestProcessDates(unittest.TestCase): + def test_process_dates_with_valid_input(self): + # Test with a valid input + input_dates = ["2023-01-01", "2023-01-03", "2023-01-05"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_empty_input(self): + # Test with an empty input + input_dates = [] + d = 2 + expected_output = [] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_single_date(self): + # Test with a single date in the input + input_dates = ["2023-01-01"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py new file mode 100644 index 0000000..f630d94 --- /dev/null +++ b/tests/unit/test_summary_retriever_base.py @@ -0,0 +1,30 @@ +from functools 
import partial +from unittest import TestCase +from unittest.mock import MagicMock + +from retrievers.summary_retriever_base import BaseSummarySearch +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex +from llama_index.schema import NodeWithScore + + +class TestSummaryRetrieverBase(TestCase): + def test_initialize_class(self): + BaseSummarySearch._setup_index = MagicMock() + doc = Document(text="SAMPLESAMPLESAMPLE") + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + BaseSummarySearch._setup_index.return_value = VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + + base_summary_search = BaseSummarySearch( + table_name="sample", + dbname="sample", + embedding_model=mock_embedding_model(), + ) + nodes = base_summary_search.get_similar_nodes(query="what is samplesample?") + self.assertIsInstance(nodes, list) + self.assertIsInstance(nodes[0], NodeWithScore) diff --git a/worker.py b/worker.py new file mode 100644 index 0000000..dcea119 --- /dev/null +++ b/worker.py @@ -0,0 +1,39 @@ +from tc_messageBroker import RabbitMQ +from tc_messageBroker.rabbit_mq.event import Event +from tc_messageBroker.rabbit_mq.queue import Queue + +from celery_app.tasks import add + + +# TODO: Update according to our requirements +def do_something(recieved_data): + message = f"Calculation Results:" + print(message) + print(f"recieved_data: {recieved_data}") + add.delay(20, 14) + + +def job_recieve(broker_url, port, username, password): + rabbit_mq = RabbitMQ( + broker_url=broker_url, port=port, username=username, password=password + ) + + # TODO: Update according to our requirements + rabbit_mq.on_event(Event.HIVEMIND.INTERACTION_CREATED, do_something) + rabbit_mq.connect(Queue.HIVEMIND) + rabbit_mq.consume(Queue.HIVEMIND) + + if rabbit_mq.channel is not None: + rabbit_mq.channel.start_consuming() + else: + 
print("Connection to broker was not successful!") + + +if __name__ == "__main__": + # TODO: read from .env + broker_url = "localhost" + port = 5672 + username = "root" + password = "pass" + + job_recieve(broker_url, port, username, password) From e94bc8aec15c677f246d9755266753d51e255842 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 16:44:02 +0330 Subject: [PATCH 02/25] fix: linter issues based on superlinter rules! --- celery_app/tasks.py | 2 +- discord_query.py | 12 +++++------- retrievers/forum_summary_retriever.py | 3 +-- retrievers/summary_retriever_base.py | 5 ++--- tests/unit/test_discord_summary_retriever.py | 4 +--- tests/unit/test_summary_retriever_base.py | 2 +- worker.py | 5 ++--- 7 files changed, 13 insertions(+), 20 deletions(-) diff --git a/celery_app/tasks.py b/celery_app/tasks.py index 73652a7..e2025af 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -1,5 +1,5 @@ -from celery_app.server import app from celery_app.job_send import job_send +from celery_app.server import app # TODO: Write tasks that match our requirements diff --git a/discord_query.py b/discord_query.py index 6bb6fb8..d65829a 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,12 +1,10 @@ -from retrievers.forum_summary_retriever import ( - ForumBasedSummaryRetriever, -) +from llama_index import QueryBundle +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters +from retrievers.forum_summary_retriever import ForumBasedSummaryRetriever from retrievers.process_dates import process_dates from retrievers.utils.load_hyperparams import load_hyperparams from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess -from llama_index import QueryBundle -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters def query_discord( @@ -140,8 +138,8 @@ def query_discord_auto_filter( response = query_discord( 
community_id=community_id, query=query, - thread_names=threads, - channel_names=channels, + thread_names=list(threads), + channel_names=list(channels), days=dates_modified, ) return response diff --git a/retrievers/forum_summary_retriever.py b/retrievers/forum_summary_retriever.py index 58b4d3e..8a3c16c 100644 --- a/retrievers/forum_summary_retriever.py +++ b/retrievers/forum_summary_retriever.py @@ -1,8 +1,7 @@ +from llama_index.embeddings import BaseEmbedding from retrievers.summary_retriever_base import BaseSummarySearch from tc_hivemind_backend.embeddings.cohere import CohereEmbedding -from llama_index.embeddings import BaseEmbedding - class ForumBasedSummaryRetriever(BaseSummarySearch): def __init__( diff --git a/retrievers/summary_retriever_base.py b/retrievers/summary_retriever_base.py index 8cedca8..1cc3420 100644 --- a/retrievers/summary_retriever_base.py +++ b/retrievers/summary_retriever_base.py @@ -1,10 +1,9 @@ -from tc_hivemind_backend.embeddings.cohere import CohereEmbedding - -from tc_hivemind_backend.pg_vector_access import PGVectorAccess from llama_index import VectorStoreIndex from llama_index.embeddings import BaseEmbedding from llama_index.indices.query.schema import QueryBundle from llama_index.schema import NodeWithScore +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from tc_hivemind_backend.pg_vector_access import PGVectorAccess class BaseSummarySearch: diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py index 915742c..6c73bc3 100644 --- a/tests/unit/test_discord_summary_retriever.py +++ b/tests/unit/test_discord_summary_retriever.py @@ -3,11 +3,9 @@ from unittest import TestCase from unittest.mock import MagicMock -from retrievers.forum_summary_retriever import ( - ForumBasedSummaryRetriever, -) from dateutil import parser from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex +from retrievers.forum_summary_retriever import 
ForumBasedSummaryRetriever class TestDiscordSummaryRetriever(TestCase): diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py index f630d94..0575a4b 100644 --- a/tests/unit/test_summary_retriever_base.py +++ b/tests/unit/test_summary_retriever_base.py @@ -2,9 +2,9 @@ from unittest import TestCase from unittest.mock import MagicMock -from retrievers.summary_retriever_base import BaseSummarySearch from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex from llama_index.schema import NodeWithScore +from retrievers.summary_retriever_base import BaseSummarySearch class TestSummaryRetrieverBase(TestCase): diff --git a/worker.py b/worker.py index dcea119..fca2379 100644 --- a/worker.py +++ b/worker.py @@ -1,13 +1,12 @@ +from celery_app.tasks import add from tc_messageBroker import RabbitMQ from tc_messageBroker.rabbit_mq.event import Event from tc_messageBroker.rabbit_mq.queue import Queue -from celery_app.tasks import add - # TODO: Update according to our requirements def do_something(recieved_data): - message = f"Calculation Results:" + message = "Calculation Results:" print(message) print(f"recieved_data: {recieved_data}") add.delay(20, 14) From 41fb2e4662b17fbe7178efd94bb57f26435cad24 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 16:48:21 +0330 Subject: [PATCH 03/25] fix: mypy linter error! 
--- discord_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord_query.py b/discord_query.py index d65829a..cfd9cb3 100644 --- a/discord_query.py +++ b/discord_query.py @@ -133,7 +133,7 @@ def query_discord_auto_filter( similarity_top_k=similarity_top_k, ) - dates_modified = process_dates(dates, d) + dates_modified = process_dates(list(dates), d) response = query_discord( community_id=community_id, From 0d6e038471d61ec0932810a34c31ea3f3bf79006 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 17:09:57 +0330 Subject: [PATCH 04/25] update: code restructuring! --- {retrievers => bot}/__init__.py | 0 {retrievers/utils => bot/retrievers}/__init__.py | 0 {retrievers => bot/retrievers}/forum_summary_retriever.py | 0 {retrievers => bot/retrievers}/process_dates.py | 0 {retrievers => bot/retrievers}/summary_retriever_base.py | 0 bot/retrievers/utils/__init__.py | 0 {retrievers => bot/retrievers}/utils/load_hyperparams.py | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename {retrievers => bot}/__init__.py (100%) rename {retrievers/utils => bot/retrievers}/__init__.py (100%) rename {retrievers => bot/retrievers}/forum_summary_retriever.py (100%) rename {retrievers => bot/retrievers}/process_dates.py (100%) rename {retrievers => bot/retrievers}/summary_retriever_base.py (100%) create mode 100644 bot/retrievers/utils/__init__.py rename {retrievers => bot/retrievers}/utils/load_hyperparams.py (100%) diff --git a/retrievers/__init__.py b/bot/__init__.py similarity index 100% rename from retrievers/__init__.py rename to bot/__init__.py diff --git a/retrievers/utils/__init__.py b/bot/retrievers/__init__.py similarity index 100% rename from retrievers/utils/__init__.py rename to bot/retrievers/__init__.py diff --git a/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py similarity index 100% rename from retrievers/forum_summary_retriever.py rename to bot/retrievers/forum_summary_retriever.py diff 
--git a/retrievers/process_dates.py b/bot/retrievers/process_dates.py similarity index 100% rename from retrievers/process_dates.py rename to bot/retrievers/process_dates.py diff --git a/retrievers/summary_retriever_base.py b/bot/retrievers/summary_retriever_base.py similarity index 100% rename from retrievers/summary_retriever_base.py rename to bot/retrievers/summary_retriever_base.py diff --git a/bot/retrievers/utils/__init__.py b/bot/retrievers/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/retrievers/utils/load_hyperparams.py b/bot/retrievers/utils/load_hyperparams.py similarity index 100% rename from retrievers/utils/load_hyperparams.py rename to bot/retrievers/utils/load_hyperparams.py From f4c9980f2e010a03c6f0526d8a50594b2aa43f79 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 17:19:14 +0330 Subject: [PATCH 05/25] feat: update imports based on code restructuring! --- bot/retrievers/forum_summary_retriever.py | 2 +- discord_query.py | 6 +++--- tests/unit/test_discord_summary_retriever.py | 2 +- tests/unit/test_process_dates_forum_retriever_search.py | 2 +- tests/unit/test_summary_retriever_base.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index 8a3c16c..75df242 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -1,5 +1,5 @@ from llama_index.embeddings import BaseEmbedding -from retrievers.summary_retriever_base import BaseSummarySearch +from bot.retrievers.summary_retriever_base import BaseSummarySearch from tc_hivemind_backend.embeddings.cohere import CohereEmbedding diff --git a/discord_query.py b/discord_query.py index cfd9cb3..c0d9b37 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,8 +1,8 @@ from llama_index import QueryBundle from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters -from 
retrievers.forum_summary_retriever import ForumBasedSummaryRetriever -from retrievers.process_dates import process_dates -from retrievers.utils.load_hyperparams import load_hyperparams +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.process_dates import process_dates +from bot.retrievers.utils.load_hyperparams import load_hyperparams from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py index 6c73bc3..92448c4 100644 --- a/tests/unit/test_discord_summary_retriever.py +++ b/tests/unit/test_discord_summary_retriever.py @@ -5,7 +5,7 @@ from dateutil import parser from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex -from retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever class TestDiscordSummaryRetriever(TestCase): diff --git a/tests/unit/test_process_dates_forum_retriever_search.py b/tests/unit/test_process_dates_forum_retriever_search.py index f580c82..9b44f3c 100644 --- a/tests/unit/test_process_dates_forum_retriever_search.py +++ b/tests/unit/test_process_dates_forum_retriever_search.py @@ -1,6 +1,6 @@ import unittest -from retrievers.process_dates import process_dates +from bot.retrievers.process_dates import process_dates class TestProcessDates(unittest.TestCase): diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py index 0575a4b..3f027db 100644 --- a/tests/unit/test_summary_retriever_base.py +++ b/tests/unit/test_summary_retriever_base.py @@ -4,7 +4,7 @@ from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex from llama_index.schema import NodeWithScore -from retrievers.summary_retriever_base import BaseSummarySearch +from 
bot.retrievers.summary_retriever_base import BaseSummarySearch class TestSummaryRetrieverBase(TestCase): From 1c9790fd7bd93adebd1716e289893faf99bea909 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 17:26:19 +0330 Subject: [PATCH 06/25] feat: Added rabbitmq and .env variable reading! --- celery_app/server.py | 11 +++- celery_app/tasks.py | 10 ++-- docker-compose.test.yml | 15 +++++ tests/unit/test_load_credentials.py | 41 ++++++++++++++ .../test_load_retriever_hyperparameters.py | 2 +- utils/__init__.py | 0 utils/credentials.py | 55 +++++++++++++++++++ worker.py | 12 ++-- 8 files changed, 134 insertions(+), 12 deletions(-) create mode 100644 tests/unit/test_load_credentials.py create mode 100644 utils/__init__.py create mode 100644 utils/credentials.py diff --git a/celery_app/server.py b/celery_app/server.py index c1c44d4..e9e2743 100644 --- a/celery_app/server.py +++ b/celery_app/server.py @@ -1,5 +1,12 @@ from celery import Celery -# TODO: read from .env -app = Celery("celery_app/tasks", broker="pyamqp://root:pass@localhost//") +from utils.credentials import load_rabbitmq_credentials + +rabbit_creds = load_rabbitmq_credentials() +user = rabbit_creds['user'] +password = rabbit_creds['password'] +host = rabbit_creds['host'] +port = rabbit_creds['port'] + +app = Celery("celery_app/tasks", broker=f"pyamqp://{user}:{password}@{host}:{port}//") app.autodiscover_tasks(["celery_app"]) diff --git a/celery_app/tasks.py b/celery_app/tasks.py index e2025af..b4fea92 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -1,15 +1,17 @@ from celery_app.job_send import job_send from celery_app.server import app +from utils.credentials import load_rabbitmq_credentials # TODO: Write tasks that match our requirements @app.task def add(x, y): - broker_url = "localhost" - port = 5672 - username = "root" - password = "pass" + rabbit_creds = load_rabbitmq_credentials() + username = rabbit_creds['user'] + password = rabbit_creds['password'] + broker_url = 
rabbit_creds['host'] + port = rabbit_creds['port'] res = x + y job_send(broker_url, port, username, password, res) diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 97fbcea..4e046d9 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -22,6 +22,10 @@ services: - POSTGRES_USER=root - POSTGRES_PASS=pass - POSTGRES_PORT=5432 + - RABBIT_HOST=rabbitmq + - RABBIT_PORT=5672 + - RABBIT_USER=root + - RABBIT_PASSWORD=pass - CHUNK_SIZE=512 - EMBEDDING_DIM=1024 - K1_RETRIEVER_SEARCH=20 @@ -69,3 +73,14 @@ services: interval: 10s timeout: 5s retries: 5 + rabbitmq: + image: "rabbitmq:3-management-alpine" + environment: + - RABBITMQ_DEFAULT_USER=root + - RABBITMQ_DEFAULT_PASS=pass + healthcheck: + test: rabbitmq-diagnostics -q ping + interval: 30s + timeout: 30s + retries: 2 + start_period: 40s diff --git a/tests/unit/test_load_credentials.py b/tests/unit/test_load_credentials.py new file mode 100644 index 0000000..701b5e6 --- /dev/null +++ b/tests/unit/test_load_credentials.py @@ -0,0 +1,41 @@ +import unittest + +from utils.credentials import load_postgres_credentials, load_rabbitmq_credentials + + +class TestCredentialLoadings(unittest.TestCase): + def test_postgresql_envs_check_type(self): + postgres_creds = load_postgres_credentials() + + self.assertIsInstance(postgres_creds, dict) + + def test_postgresql_envs_values(self): + postgres_creds = load_postgres_credentials() + + self.assertNotEqual(postgres_creds["user"], None) + self.assertNotEqual(postgres_creds["password"], None) + self.assertNotEqual(postgres_creds["host"], None) + self.assertNotEqual(postgres_creds["port"], None) + + self.assertIsInstance(postgres_creds["user"], str) + self.assertIsInstance(postgres_creds["password"], str) + self.assertIsInstance(postgres_creds["host"], str) + self.assertIsInstance(postgres_creds["port"], str) + + def test_rabbitmq_envs_check_type(self): + rabbitmq_creds = load_rabbitmq_credentials() + + self.assertIsInstance(rabbitmq_creds, dict) + + def 
test_rabbitmq_envs_values(self): + rabbitmq_creds = load_postgres_credentials() + + self.assertNotEqual(rabbitmq_creds["user"], None) + self.assertNotEqual(rabbitmq_creds["password"], None) + self.assertNotEqual(rabbitmq_creds["host"], None) + self.assertNotEqual(rabbitmq_creds["port"], None) + + self.assertIsInstance(rabbitmq_creds["user"], str) + self.assertIsInstance(rabbitmq_creds["password"], str) + self.assertIsInstance(rabbitmq_creds["host"], str) + self.assertIsInstance(rabbitmq_creds["port"], str) \ No newline at end of file diff --git a/tests/unit/test_load_retriever_hyperparameters.py b/tests/unit/test_load_retriever_hyperparameters.py index eadcbdc..1f9c2fa 100644 --- a/tests/unit/test_load_retriever_hyperparameters.py +++ b/tests/unit/test_load_retriever_hyperparameters.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch -from retrievers.utils.load_hyperparams import load_hyperparams +from bot.retrievers.utils.load_hyperparams import load_hyperparams class TestLoadHyperparams(unittest.TestCase): diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/credentials.py b/utils/credentials.py new file mode 100644 index 0000000..02b3fb7 --- /dev/null +++ b/utils/credentials.py @@ -0,0 +1,55 @@ +import os + +from dotenv import load_dotenv + + +def load_postgres_credentials() -> dict[str, str]: + """ + load posgresql db credentials from .env + + Returns: + --------- + postgres_creds : dict[str, Any] + postgresql credentials + a dictionary representive of + `user`: str + `password` : str + `host` : str + `port` : int + """ + load_dotenv() + + postgres_creds = {} + + postgres_creds["user"] = os.getenv("POSTGRES_USER", "") + postgres_creds["password"] = os.getenv("POSTGRES_PASS", "") + postgres_creds["host"] = os.getenv("POSTGRES_HOST", "") + postgres_creds["port"] = os.getenv("POSTGRES_PORT", "") + + return postgres_creds + + +def load_rabbitmq_credentials() -> dict[str, str]: + """ + load 
rabbitmq credentials from .env + + Returns: + --------- + rabbitmq_creds : dict[str, Any] + rabbitmq credentials + a dictionary representive of + `user`: str + `password` : str + `host` : str + `port` : int + """ + load_dotenv() + + rabbitmq_creds = {} + + rabbitmq_creds["user"] = os.getenv("RABBIT_USER", "") + rabbitmq_creds["password"] = os.getenv("RABBIT_PASSWORD", "") + rabbitmq_creds["host"] = os.getenv("RABBIT_HOST", "") + rabbitmq_creds["port"] = os.getenv("RABBIT_PORT", "") + + return rabbitmq_creds \ No newline at end of file diff --git a/worker.py b/worker.py index fca2379..1104bbd 100644 --- a/worker.py +++ b/worker.py @@ -3,6 +3,8 @@ from tc_messageBroker.rabbit_mq.event import Event from tc_messageBroker.rabbit_mq.queue import Queue +from utils.credentials import load_rabbitmq_credentials + # TODO: Update according to our requirements def do_something(recieved_data): @@ -29,10 +31,10 @@ def job_recieve(broker_url, port, username, password): if __name__ == "__main__": - # TODO: read from .env - broker_url = "localhost" - port = 5672 - username = "root" - password = "pass" + rabbit_creds = load_rabbitmq_credentials() + username = rabbit_creds['user'] + password = rabbit_creds['password'] + broker_url = rabbit_creds['host'] + port = rabbit_creds['port'] job_recieve(broker_url, port, username, password) From 655df709d049e58d5bc959cb660cd9a96e3a1735 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 17:33:47 +0330 Subject: [PATCH 07/25] fix: linter issues based on superlinter rules! 
--- celery_app/server.py | 9 ++++----- celery_app/tasks.py | 9 +++++---- discord_query.py | 4 ++-- tests/unit/test_discord_summary_retriever.py | 2 +- tests/unit/test_load_credentials.py | 2 +- tests/unit/test_summary_retriever_base.py | 2 +- utils/credentials.py | 2 +- worker.py | 9 ++++----- 8 files changed, 19 insertions(+), 20 deletions(-) diff --git a/celery_app/server.py b/celery_app/server.py index e9e2743..499aff0 100644 --- a/celery_app/server.py +++ b/celery_app/server.py @@ -1,12 +1,11 @@ from celery import Celery - from utils.credentials import load_rabbitmq_credentials rabbit_creds = load_rabbitmq_credentials() -user = rabbit_creds['user'] -password = rabbit_creds['password'] -host = rabbit_creds['host'] -port = rabbit_creds['port'] +user = rabbit_creds["user"] +password = rabbit_creds["password"] +host = rabbit_creds["host"] +port = rabbit_creds["port"] app = Celery("celery_app/tasks", broker=f"pyamqp://{user}:{password}@{host}:{port}//") app.autodiscover_tasks(["celery_app"]) diff --git a/celery_app/tasks.py b/celery_app/tasks.py index b4fea92..ccd4c5e 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -2,16 +2,17 @@ from celery_app.server import app from utils.credentials import load_rabbitmq_credentials + # TODO: Write tasks that match our requirements @app.task def add(x, y): rabbit_creds = load_rabbitmq_credentials() - username = rabbit_creds['user'] - password = rabbit_creds['password'] - broker_url = rabbit_creds['host'] - port = rabbit_creds['port'] + username = rabbit_creds["user"] + password = rabbit_creds["password"] + broker_url = rabbit_creds["host"] + port = rabbit_creds["port"] res = x + y job_send(broker_url, port, username, password, res) diff --git a/discord_query.py b/discord_query.py index c0d9b37..ae0865a 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,8 +1,8 @@ -from llama_index import QueryBundle -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from 
bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever from bot.retrievers.process_dates import process_dates from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index import QueryBundle +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py index 92448c4..d5fafa3 100644 --- a/tests/unit/test_discord_summary_retriever.py +++ b/tests/unit/test_discord_summary_retriever.py @@ -3,9 +3,9 @@ from unittest import TestCase from unittest.mock import MagicMock +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever from dateutil import parser from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex -from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever class TestDiscordSummaryRetriever(TestCase): diff --git a/tests/unit/test_load_credentials.py b/tests/unit/test_load_credentials.py index 701b5e6..5bfa795 100644 --- a/tests/unit/test_load_credentials.py +++ b/tests/unit/test_load_credentials.py @@ -38,4 +38,4 @@ def test_rabbitmq_envs_values(self): self.assertIsInstance(rabbitmq_creds["user"], str) self.assertIsInstance(rabbitmq_creds["password"], str) self.assertIsInstance(rabbitmq_creds["host"], str) - self.assertIsInstance(rabbitmq_creds["port"], str) \ No newline at end of file + self.assertIsInstance(rabbitmq_creds["port"], str) diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py index 3f027db..14180ac 100644 --- a/tests/unit/test_summary_retriever_base.py +++ b/tests/unit/test_summary_retriever_base.py @@ -2,9 +2,9 @@ from unittest import TestCase from unittest.mock import MagicMock +from bot.retrievers.summary_retriever_base import 
BaseSummarySearch from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex from llama_index.schema import NodeWithScore -from bot.retrievers.summary_retriever_base import BaseSummarySearch class TestSummaryRetrieverBase(TestCase): diff --git a/utils/credentials.py b/utils/credentials.py index 02b3fb7..234c82c 100644 --- a/utils/credentials.py +++ b/utils/credentials.py @@ -52,4 +52,4 @@ def load_rabbitmq_credentials() -> dict[str, str]: rabbitmq_creds["host"] = os.getenv("RABBIT_HOST", "") rabbitmq_creds["port"] = os.getenv("RABBIT_PORT", "") - return rabbitmq_creds \ No newline at end of file + return rabbitmq_creds diff --git a/worker.py b/worker.py index 1104bbd..48ddb05 100644 --- a/worker.py +++ b/worker.py @@ -2,7 +2,6 @@ from tc_messageBroker import RabbitMQ from tc_messageBroker.rabbit_mq.event import Event from tc_messageBroker.rabbit_mq.queue import Queue - from utils.credentials import load_rabbitmq_credentials @@ -32,9 +31,9 @@ def job_recieve(broker_url, port, username, password): if __name__ == "__main__": rabbit_creds = load_rabbitmq_credentials() - username = rabbit_creds['user'] - password = rabbit_creds['password'] - broker_url = rabbit_creds['host'] - port = rabbit_creds['port'] + username = rabbit_creds["user"] + password = rabbit_creds["password"] + broker_url = rabbit_creds["host"] + port = rabbit_creds["port"] job_recieve(broker_url, port, username, password) From 130553308b9e8c636d87c3141832433fbb396c03 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 28 Dec 2023 17:38:09 +0330 Subject: [PATCH 08/25] fix: isort linter issues! 
--- bot/retrievers/forum_summary_retriever.py | 2 +- celery_app/tasks.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index 75df242..1e04cea 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -1,5 +1,5 @@ -from llama_index.embeddings import BaseEmbedding from bot.retrievers.summary_retriever_base import BaseSummarySearch +from llama_index.embeddings import BaseEmbedding from tc_hivemind_backend.embeddings.cohere import CohereEmbedding diff --git a/celery_app/tasks.py b/celery_app/tasks.py index ccd4c5e..5dd154e 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -1,6 +1,5 @@ from celery_app.job_send import job_send from celery_app.server import app - from utils.credentials import load_rabbitmq_credentials # TODO: Write tasks that match our requirements From ab04e8bd299d7602c119021196c2ecd41f71bb1f Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Sun, 31 Dec 2023 11:32:08 +0330 Subject: [PATCH 09/25] update: shared codes library! --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c7886f5..f04f1e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,5 +16,5 @@ neo4j>=5.14.1, <6.0.0 coverage>=7.3.3, <8.0.0 pytest>=7.4.3, <8.0.0 python-dotenv==1.0.0 -tc_hivemind_backend==1.0.0 +tc-hivemind-backend==1.0.0 celery>=5.3.6, <6.0.0 From c38c2f2af93c9b56c086427b6295e50a6e7741d4 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 1 Jan 2024 12:42:07 +0330 Subject: [PATCH 10/25] feat: separate the query engine!
--- .gitignore | 3 ++- discord_query.py | 65 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 1cd6533..3a5bd6b 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -hivemind-bot-env/* \ No newline at end of file +hivemind-bot-env/* +main.ipynb \ No newline at end of file diff --git a/discord_query.py b/discord_query.py index ae0865a..aedfe79 100644 --- a/discord_query.py +++ b/discord_query.py @@ -2,19 +2,19 @@ from bot.retrievers.process_dates import process_dates from bot.retrievers.utils.load_hyperparams import load_hyperparams from llama_index import QueryBundle +from llama_index.core import BaseQueryEngine from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess -def query_discord( +def create_discord_engine( community_id: str, - query: str, thread_names: list[str], channel_names: list[str], days: list[str], similarity_top_k: int | None = None, -) -> str: +) -> BaseQueryEngine: """ query the discord database using filters given and give an anwer to the given query using the LLM @@ -37,18 +37,16 @@ def query_discord( Returns --------- - response : str - the LLM response given the query + query_engine : BaseQueryEngine + the created query engine with the filters """ - if similarity_top_k is None: - _, similarity_top_k, _ = load_hyperparams() - table_name = "discord" dbname = f"community_{community_id}" pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) - index = pg_vector.load_index() + if similarity_top_k is None: + _, similarity_top_k, _ = load_hyperparams() thread_filters: list[ExactMatchFilter] = [] channel_filters: list[ExactMatchFilter] = [] @@ -76,22 +74,17 @@ def 
query_discord( filters=filters, similarity_top_k=similarity_top_k ) - query_bundle = QueryBundle( - query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) - ) - response = query_engine.query(query_bundle) + return query_engine - return response.response - -def query_discord_auto_filter( +def create_discord_engine_auto_filter( community_id: str, query: str, similarity_top_k: int | None = None, d: int | None = None, -) -> str: +) -> BaseQueryEngine: """ - get the query results and do the filtering automatically. + get the query engine and do the filtering automatically. By automatically we mean, it would first query the summaries to get the metadata filters @@ -106,14 +99,14 @@ def query_discord_auto_filter( to get the `k2` count simliar nodes if `None`, then would read from `.env` d : int - this would make the secondary search (`query_discord`) + this would make the secondary search (`create_discord_engine`) to be done on the `metadata.date - d` to `metadata.date + d` Returns --------- - response : str - the LLM response given the query + query_engine : BaseQueryEngine + the created query engine with the filters """ table_name = "discord_summary" dbname = f"community_{community_id}" @@ -135,11 +128,39 @@ def query_discord_auto_filter( dates_modified = process_dates(list(dates), d) - response = query_discord( + engine = create_discord_engine( community_id=community_id, query=query, thread_names=list(threads), channel_names=list(channels), days=dates_modified, ) + return engine + + +def query_discord( + community_id: str, + query: str, +) -> str: + """ + query the llm using the query engine + + Parameters + ------------ + query_engine : BaseQueryEngine + the prepared query engine + query : str + the string question + """ + query_engine = create_discord_engine_auto_filter( + community_id=community_id, + query=query, + ) + + query_bundle = QueryBundle( + query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) + ) + + response = 
query_engine.query(query_bundle) + return response From ede353c36147e497495739386bd2ef30a2c91f80 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 1 Jan 2024 14:43:39 +0330 Subject: [PATCH 11/25] feat: Added subquery reponse generator! --- discord_query.py | 142 +-------------------- requirements.txt | 1 + subquery.py | 104 +++++++++++++++ utils/query_engine/__init__.py | 2 + utils/query_engine/discord_query_engine.py | 136 ++++++++++++++++++++ 5 files changed, 245 insertions(+), 140 deletions(-) create mode 100644 subquery.py create mode 100644 utils/query_engine/__init__.py create mode 100644 utils/query_engine/discord_query_engine.py diff --git a/discord_query.py b/discord_query.py index aedfe79..b6ec65b 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,141 +1,6 @@ -from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever -from bot.retrievers.process_dates import process_dates -from bot.retrievers.utils.load_hyperparams import load_hyperparams +from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter from llama_index import QueryBundle -from llama_index.core import BaseQueryEngine -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from tc_hivemind_backend.embeddings.cohere import CohereEmbedding -from tc_hivemind_backend.pg_vector_access import PGVectorAccess - - -def create_discord_engine( - community_id: str, - thread_names: list[str], - channel_names: list[str], - days: list[str], - similarity_top_k: int | None = None, -) -> BaseQueryEngine: - """ - query the discord database using filters given - and give an anwer to the given query using the LLM - - Parameters - ------------ - guild_id : str - the discord guild data to query - query : str - the query (question) of the user - thread_names : list[str] - the given threads to search for - channel_names : list[str] - the given channels to search for - days : list[str] - the given days to search for - 
similarity_top_k : int | None - the k similar results to use when querying the data - if `None` will load from `.env` file - - Returns - --------- - query_engine : BaseQueryEngine - the created query engine with the filters - """ - table_name = "discord" - dbname = f"community_{community_id}" - - pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) - index = pg_vector.load_index() - if similarity_top_k is None: - _, similarity_top_k, _ = load_hyperparams() - - thread_filters: list[ExactMatchFilter] = [] - channel_filters: list[ExactMatchFilter] = [] - day_filters: list[ExactMatchFilter] = [] - - for channel in channel_names: - channel_updated = channel.replace("'", "''") - channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated)) - - for thread in thread_names: - thread_updated = thread.replace("'", "''") - thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated)) - - for day in days: - day_filters.append(ExactMatchFilter(key="date", value=day)) - - all_filters: list[ExactMatchFilter] = [] - all_filters.extend(thread_filters) - all_filters.extend(channel_filters) - all_filters.extend(day_filters) - - filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) - - query_engine = index.as_query_engine( - filters=filters, similarity_top_k=similarity_top_k - ) - - return query_engine - - -def create_discord_engine_auto_filter( - community_id: str, - query: str, - similarity_top_k: int | None = None, - d: int | None = None, -) -> BaseQueryEngine: - """ - get the query engine and do the filtering automatically. 
- By automatically we mean, it would first query the summaries - to get the metadata filters - - Parameters - ----------- - guild_id : str - the discord guild data to query - query : str - the query (question) of the user - similarity_top_k : int | None - the value for the initial summary search - to get the `k2` count simliar nodes - if `None`, then would read from `.env` - d : int - this would make the secondary search (`create_discord_engine`) - to be done on the `metadata.date - d` to `metadata.date + d` - - - Returns - --------- - query_engine : BaseQueryEngine - the created query engine with the filters - """ - table_name = "discord_summary" - dbname = f"community_{community_id}" - - if d is None: - _, _, d = load_hyperparams() - if similarity_top_k is None: - similarity_top_k, _, _ = load_hyperparams() - - discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) - - channels, threads, dates = discord_retriever.retreive_metadata( - query=query, - metadata_group1_key="channel", - metadata_group2_key="thread", - metadata_date_key="date", - similarity_top_k=similarity_top_k, - ) - - dates_modified = process_dates(list(dates), d) - - engine = create_discord_engine( - community_id=community_id, - query=query, - thread_names=list(threads), - channel_names=list(channels), - days=dates_modified, - ) - return engine def query_discord( @@ -152,15 +17,12 @@ def query_discord( query : str the string question """ - query_engine = create_discord_engine_auto_filter( + query_engine = prepare_discord_engine_auto_filter( community_id=community_id, query=query, ) - query_bundle = QueryBundle( query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) ) - response = query_engine.query(query_bundle) - return response diff --git a/requirements.txt b/requirements.txt index f04f1e7..f41372e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ pytest>=7.4.3, <8.0.0 python-dotenv==1.0.0 tc-hivemind-backend==1.0.0 
celery>=5.3.6, <6.0.0 +guidance diff --git a/subquery.py b/subquery.py new file mode 100644 index 0000000..ad3f64a --- /dev/null +++ b/subquery.py @@ -0,0 +1,104 @@ +from utils.query_engine import prepare_discord_engine_auto_filter +from llama_index.core import BaseQueryEngine +from guidance.llms import OpenAI as GuidanceOpenAI +from llama_index import QueryBundle +from llama_index.tools import QueryEngineTool, ToolMetadata +from llama_index.query_engine import SubQuestionQueryEngine +from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator + +def query_multiple_source( + query: str, + community_id: str, + discord: bool, + discourse: bool, + gdrive: bool, + notion: bool, + telegram: bool, + github: bool, + ) -> str: + """ + query multiple platforms and get an answer from the multiple + + Parameters + ------------ + query : str + the user question + community_id : str + the community id to get their data + discord : bool + if `True` then add the engine to the subquery_generator + discourse : bool + if `True` then add the engine to the subquery_generator + gdrive : bool + if `True` then add the engine to the subquery_generator + notion : bool + if `True` then add the engine to the subquery_generator + telegram : bool + if `True` then add the engine to the subquery_generator + github : bool + if `True` then add the engine to the subquery_generator + + + Returns + -------- + reponse : str + the response to the user query from the LLM + using the engines of the given platforms (pltform equal to True) + """ + query_engine_tools: list[QueryEngineTool] = [] + tools: list[ToolMetadata] = [] + + discord_query_engine: BaseQueryEngine + discourse_query_engine: BaseQueryEngine + gdrive_query_engine: BaseQueryEngine + notion_query_engine: BaseQueryEngine + telegram_query_engine: BaseQueryEngine + github_query_engine: BaseQueryEngine + + # query engine perparation + # tools_metadata and query_engine_tools + if discord: + discord_query_engine = 
prepare_discord_engine_auto_filter( + community_id, + query, + similarity_top_k=None, + d=None, + ) + tool_metadata = ToolMetadata( + name="Discord", + description="Provides the discord platform conversations data." + ) + + tools.append(tool_metadata) + query_engine_tools.append( + QueryEngineTool( + query_engine=discord_query_engine, + metadata=tool_metadata, + ) + ) + + if discourse: + raise NotImplementedError + if gdrive: + raise NotImplementedError + if notion: + raise NotImplementedError + if telegram: + raise NotImplementedError + if github: + raise NotImplementedError + + + question_gen = GuidanceQuestionGenerator.from_defaults( + guidance_llm=GuidanceOpenAI("text-davinci-003"), verbose=False + ) + + s_engine = SubQuestionQueryEngine.from_defaults( + question_gen=question_gen, + query_engine_tools=query_engine_tools, + ) + reponse = s_engine.query( + QueryBundle(query) + ) + + return reponse.response \ No newline at end of file diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py new file mode 100644 index 0000000..0de4592 --- /dev/null +++ b/utils/query_engine/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from discord_query_engine import prepare_discord_engine_auto_filter \ No newline at end of file diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/discord_query_engine.py new file mode 100644 index 0000000..0af3064 --- /dev/null +++ b/utils/query_engine/discord_query_engine.py @@ -0,0 +1,136 @@ +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.process_dates import process_dates +from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.core import BaseQueryEngine +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters +from tc_hivemind_backend.pg_vector_access import PGVectorAccess + + +def prepare_discord_engine( + community_id: str, + thread_names: list[str], + channel_names: list[str], + days: 
list[str], + similarity_top_k: int | None = None, +) -> BaseQueryEngine: + """ + query the discord database using filters given + and give an anwer to the given query using the LLM + + Parameters + ------------ + guild_id : str + the discord guild data to query + query : str + the query (question) of the user + thread_names : list[str] + the given threads to search for + channel_names : list[str] + the given channels to search for + days : list[str] + the given days to search for + similarity_top_k : int | None + the k similar results to use when querying the data + if `None` will load from `.env` file + + Returns + --------- + query_engine : BaseQueryEngine + the created query engine with the filters + """ + table_name = "discord" + dbname = f"community_{community_id}" + + pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) + index = pg_vector.load_index() + if similarity_top_k is None: + _, similarity_top_k, _ = load_hyperparams() + + thread_filters: list[ExactMatchFilter] = [] + channel_filters: list[ExactMatchFilter] = [] + day_filters: list[ExactMatchFilter] = [] + + for channel in channel_names: + channel_updated = channel.replace("'", "''") + channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated)) + + for thread in thread_names: + thread_updated = thread.replace("'", "''") + thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated)) + + for day in days: + day_filters.append(ExactMatchFilter(key="date", value=day)) + + all_filters: list[ExactMatchFilter] = [] + all_filters.extend(thread_filters) + all_filters.extend(channel_filters) + all_filters.extend(day_filters) + + filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) + + query_engine = index.as_query_engine( + filters=filters, similarity_top_k=similarity_top_k + ) + + return query_engine + + +def prepare_discord_engine_auto_filter( + community_id: str, + query: str, + similarity_top_k: int | None = None, + d: int | None = None, 
+) -> BaseQueryEngine: + """ + get the query engine and do the filtering automatically. + By automatically we mean, it would first query the summaries + to get the metadata filters + + Parameters + ----------- + guild_id : str + the discord guild data to query + query : str + the query (question) of the user + similarity_top_k : int | None + the value for the initial summary search + to get the `k2` count simliar nodes + if `None`, then would read from `.env` + d : int + this would make the secondary search (`prepare_discord_engine`) + to be done on the `metadata.date - d` to `metadata.date + d` + + + Returns + --------- + query_engine : BaseQueryEngine + the created query engine with the filters + """ + table_name = "discord_summary" + dbname = f"community_{community_id}" + + if d is None: + _, _, d = load_hyperparams() + if similarity_top_k is None: + similarity_top_k, _, _ = load_hyperparams() + + discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) + + channels, threads, dates = discord_retriever.retreive_metadata( + query=query, + metadata_group1_key="channel", + metadata_group2_key="thread", + metadata_date_key="date", + similarity_top_k=similarity_top_k, + ) + + dates_modified = process_dates(list(dates), d) + + engine = prepare_discord_engine( + community_id=community_id, + query=query, + thread_names=list(threads), + channel_names=list(channels), + days=dates_modified, + ) + return engine \ No newline at end of file From 073cbb7e5c0b5b4c1bd15722b465089c3695fd33 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 1 Jan 2024 15:01:36 +0330 Subject: [PATCH 12/25] fix: wrong import and cleaning code for linters! 
--- subquery.py | 32 ++++++++++------------ utils/query_engine/__init__.py | 2 +- utils/query_engine/discord_query_engine.py | 2 +- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/subquery.py b/subquery.py index ad3f64a..d954934 100644 --- a/subquery.py +++ b/subquery.py @@ -1,21 +1,22 @@ from utils.query_engine import prepare_discord_engine_auto_filter from llama_index.core import BaseQueryEngine -from guidance.llms import OpenAI as GuidanceOpenAI +from guidance.models import OpenAI as GuidanceOpenAI from llama_index import QueryBundle from llama_index.tools import QueryEngineTool, ToolMetadata from llama_index.query_engine import SubQuestionQueryEngine from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator + def query_multiple_source( - query: str, - community_id: str, - discord: bool, - discourse: bool, - gdrive: bool, - notion: bool, - telegram: bool, - github: bool, - ) -> str: + query: str, + community_id: str, + discord: bool, + discourse: bool, + gdrive: bool, + notion: bool, + telegram: bool, + github: bool, +) -> str: """ query multiple platforms and get an answer from the multiple @@ -66,7 +67,7 @@ def query_multiple_source( ) tool_metadata = ToolMetadata( name="Discord", - description="Provides the discord platform conversations data." 
+ description="Provides the discord platform conversations data.", ) tools.append(tool_metadata) @@ -76,7 +77,7 @@ def query_multiple_source( metadata=tool_metadata, ) ) - + if discourse: raise NotImplementedError if gdrive: @@ -87,7 +88,6 @@ def query_multiple_source( raise NotImplementedError if github: raise NotImplementedError - question_gen = GuidanceQuestionGenerator.from_defaults( guidance_llm=GuidanceOpenAI("text-davinci-003"), verbose=False @@ -97,8 +97,6 @@ def query_multiple_source( question_gen=question_gen, query_engine_tools=query_engine_tools, ) - reponse = s_engine.query( - QueryBundle(query) - ) + reponse = s_engine.query(QueryBundle(query)) - return reponse.response \ No newline at end of file + return reponse.response diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index 0de4592..115169c 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -from discord_query_engine import prepare_discord_engine_auto_filter \ No newline at end of file +from discord_query_engine import prepare_discord_engine_auto_filter diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/discord_query_engine.py index 0af3064..56496cb 100644 --- a/utils/query_engine/discord_query_engine.py +++ b/utils/query_engine/discord_query_engine.py @@ -133,4 +133,4 @@ def prepare_discord_engine_auto_filter( channel_names=list(channels), days=dates_modified, ) - return engine \ No newline at end of file + return engine From 36cd0dcc17fab59cc499f967aa8ceac931537bea Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 1 Jan 2024 15:23:56 +0330 Subject: [PATCH 13/25] feat: Applying cohere embedding! 
--- subquery.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/subquery.py b/subquery.py index d954934..1e3ef9c 100644 --- a/subquery.py +++ b/subquery.py @@ -5,6 +5,7 @@ from llama_index.tools import QueryEngineTool, ToolMetadata from llama_index.query_engine import SubQuestionQueryEngine from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding def query_multiple_source( @@ -97,6 +98,10 @@ def query_multiple_source( question_gen=question_gen, query_engine_tools=query_engine_tools, ) - reponse = s_engine.query(QueryBundle(query)) + reponse = s_engine.query( + QueryBundle( + query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) + ) + ) return reponse.response From 058b528c1de2fde9fb65d412c2e80d7d790d3f61 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 10:30:48 +0330 Subject: [PATCH 14/25] update: test case for discord secondary search! 
--- .../unit/test_prepare_discord_query_engine.py | 50 +++++++++++++++++++ utils/query_engine/__init__.py | 2 +- utils/query_engine/discord_query_engine.py | 8 ++- 3 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_prepare_discord_query_engine.py diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py new file mode 100644 index 0000000..7ac1182 --- /dev/null +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -0,0 +1,50 @@ +import unittest +import os +from unittest.mock import patch, Mock +from utils.query_engine.discord_query_engine import prepare_discord_engine +from llama_index.core import BaseQueryEngine +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters + + +class TestPrepareDiscordEngine(unittest.TestCase): + def setUp(self): + # Set up environment variables for testing + os.environ["CHUNK_SIZE"] = "128" + os.environ["EMBEDDING_DIM"] = "256" + os.environ["K1_RETRIEVER_SEARCH"] = "20" + os.environ["K2_RETRIEVER_SEARCH"] = "5" + os.environ["D_RETRIEVER_SEARCH"] = "3" + + def test_prepare_discord_engine(self): + community_id = "123456" + thread_names = ["thread1", "thread2"] + channel_names = ["channel1", "channel2"] + days = ["2022-01-01", "2022-01-02"] + + # Call the function + query_engine = prepare_discord_engine( + community_id, + thread_names, + channel_names, + days, + testing=True, + ) + + # Assertions + self.assertIsInstance(query_engine, BaseQueryEngine) + + expected_filter = MetadataFilters( + filters=[ + ExactMatchFilter(key="thread", value="thread1"), + ExactMatchFilter(key="thread", value="thread2"), + ExactMatchFilter(key="channel", value="channel1"), + ExactMatchFilter(key="channel", value="channel2"), + ExactMatchFilter(key="date", value="2022-01-01"), + ExactMatchFilter(key="date", value="2022-01-02"), + ], + condition=FilterCondition.OR, + ) + + self.assertEqual(query_engine.retriever._filters, expected_filter) + 
# this is the secondary search, so K2 should be for this + self.assertEqual(query_engine.retriever._similarity_top_k, 5) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index 115169c..fad06f9 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -from discord_query_engine import prepare_discord_engine_auto_filter +from .discord_query_engine import prepare_discord_engine_auto_filter diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/discord_query_engine.py index 56496cb..4b121df 100644 --- a/utils/query_engine/discord_query_engine.py +++ b/utils/query_engine/discord_query_engine.py @@ -12,6 +12,7 @@ def prepare_discord_engine( channel_names: list[str], days: list[str], similarity_top_k: int | None = None, + **kwarg, ) -> BaseQueryEngine: """ query the discord database using filters given @@ -32,6 +33,9 @@ def prepare_discord_engine( similarity_top_k : int | None the k similar results to use when querying the data if `None` will load from `.env` file + ** kwargs : + testing : bool + whether to setup the PGVectorAccess in testing mode Returns --------- @@ -41,7 +45,9 @@ def prepare_discord_engine( table_name = "discord" dbname = f"community_{community_id}" - pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) + testing = kwarg.get("testing", False) + + pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname, testing=testing) index = pg_vector.load_index() if similarity_top_k is None: _, similarity_top_k, _ = load_hyperparams() From 0b19e7bf7695395c4f188a46ddc43faf09cb91a2 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 10:49:42 +0330 Subject: [PATCH 15/25] fix: linter issues based on superlinter rules! 
--- discord_query.py | 2 +- subquery.py | 16 ++++++++-------- tests/unit/test_prepare_discord_query_engine.py | 1 - 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/discord_query.py b/discord_query.py index b6ec65b..a630f7e 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,6 +1,6 @@ -from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter from llama_index import QueryBundle from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter def query_discord( diff --git a/subquery.py b/subquery.py index 1e3ef9c..cf4fa10 100644 --- a/subquery.py +++ b/subquery.py @@ -1,11 +1,11 @@ -from utils.query_engine import prepare_discord_engine_auto_filter -from llama_index.core import BaseQueryEngine from guidance.models import OpenAI as GuidanceOpenAI from llama_index import QueryBundle -from llama_index.tools import QueryEngineTool, ToolMetadata +from llama_index.core import BaseQueryEngine from llama_index.query_engine import SubQuestionQueryEngine from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator +from llama_index.tools import QueryEngineTool, ToolMetadata from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from utils.query_engine import prepare_discord_engine_auto_filter def query_multiple_source( @@ -51,11 +51,11 @@ def query_multiple_source( tools: list[ToolMetadata] = [] discord_query_engine: BaseQueryEngine - discourse_query_engine: BaseQueryEngine - gdrive_query_engine: BaseQueryEngine - notion_query_engine: BaseQueryEngine - telegram_query_engine: BaseQueryEngine - github_query_engine: BaseQueryEngine + # discourse_query_engine: BaseQueryEngine + # gdrive_query_engine: BaseQueryEngine + # notion_query_engine: BaseQueryEngine + # telegram_query_engine: BaseQueryEngine + # github_query_engine: BaseQueryEngine # query engine perparation # tools_metadata and query_engine_tools diff 
--git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py index 7ac1182..26816f1 100644 --- a/tests/unit/test_prepare_discord_query_engine.py +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -1,6 +1,5 @@ import unittest import os -from unittest.mock import patch, Mock from utils.query_engine.discord_query_engine import prepare_discord_engine from llama_index.core import BaseQueryEngine from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters From da6bb2ca9fa2ca85a4d2bd10eb84274dea59a931 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 11:02:47 +0330 Subject: [PATCH 16/25] fix: isort linter issue! --- tests/unit/test_prepare_discord_query_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py index 26816f1..72ad1a6 100644 --- a/tests/unit/test_prepare_discord_query_engine.py +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -1,8 +1,9 @@ -import unittest import os -from utils.query_engine.discord_query_engine import prepare_discord_engine +import unittest + from llama_index.core import BaseQueryEngine from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters +from utils.query_engine.discord_query_engine import prepare_discord_engine class TestPrepareDiscordEngine(unittest.TestCase): From d75342bae2c112d14d327694ea985d71a88e1d78 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 15:34:27 +0330 Subject: [PATCH 17/25] feat: Added source node returning! 
--- discord_query.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/discord_query.py b/discord_query.py index a630f7e..7a61c32 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,4 +1,5 @@ from llama_index import QueryBundle +from llama_index.schema import NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter @@ -6,7 +7,7 @@ def query_discord( community_id: str, query: str, -) -> str: +) -> tuple[str, list[NodeWithScore]]: """ query the llm using the query engine @@ -25,4 +26,4 @@ def query_discord( query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) ) response = query_engine.query(query_bundle) - return response + return response.response, response.source_nodes From fe95fbff2c45c59827bb1a2e85178a54d830aebc Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 15:34:51 +0330 Subject: [PATCH 18/25] feat: Added credentials! 
--- .env.example | 26 ++++++++++++++++++++++++++ docker-compose.test.yml | 2 ++ 2 files changed, 28 insertions(+) create mode 100644 .env.example diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ccba142 --- /dev/null +++ b/.env.example @@ -0,0 +1,26 @@ +PORT= +MONGODB_HOST= +MONGODB_PORT= +MONGODB_USER= +MONGODB_PASS= +NEO4J_PROTOCOL= +NEO4J_HOST= +NEO4J_PORT= +NEO4J_USER= +NEO4J_PASSWORD= +NEO4J_DB= +POSTGRES_HOST= +POSTGRES_USER= +POSTGRES_PASS= +POSTGRES_PORT= +RABBIT_HOST= +RABBIT_PORT= +RABBIT_USER= +RABBIT_PASSWORD= +CHUNK_SIZE= +EMBEDDING_DIM= +K1_RETRIEVER_SEARCH= +K2_RETRIEVER_SEARCH= +D_RETRIEVER_SEARCH= +COHERE_API_KEY= +OPENAI_API_KEY= \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 4e046d9..6375962 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -31,6 +31,8 @@ services: - K1_RETRIEVER_SEARCH=20 - K2_RETRIEVER_SEARCH=5 - D_RETRIEVER_SEARCH=7 + - COHERE_API_KEY=some_credentials + - OPENAI_API_KEY=some_credentials2 volumes: - ./coverage:/project/coverage depends_on: From 923a3cb964da849dc491e3f3ee962d37b745fafc Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 15:40:50 +0330 Subject: [PATCH 19/25] feat: completing the function doc! --- discord_query.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/discord_query.py b/discord_query.py index 7a61c32..24e762f 100644 --- a/discord_query.py +++ b/discord_query.py @@ -17,6 +17,13 @@ def query_discord( the prepared query engine query : str the string question + + Returns + ---------- + response : str + the LLM response + source_nodes : list[llama_index.schema.NodeWithScore] + the source nodes that helped in answering the question """ query_engine = prepare_discord_engine_auto_filter( community_id=community_id, From 43cb8670c3f3ea57c36caa29f4ede5d5ded218bc Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 2 Jan 2024 15:48:15 +0330 Subject: [PATCH 20/25] fix: dotenv-linter issue! 
--- .env.example | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/.env.example b/.env.example index ccba142..e73ede4 100644 --- a/.env.example +++ b/.env.example @@ -1,26 +1,25 @@ -PORT= +CHUNK_SIZE= +COHERE_API_KEY= +D_RETRIEVER_SEARCH= +EMBEDDING_DIM= +K1_RETRIEVER_SEARCH= +K2_RETRIEVER_SEARCH= MONGODB_HOST= +MONGODB_PASS= MONGODB_PORT= MONGODB_USER= -MONGODB_PASS= -NEO4J_PROTOCOL= +NEO4J_DB= NEO4J_HOST= +NEO4J_PASSWORD= NEO4J_PORT= +NEO4J_PROTOCOL= NEO4J_USER= -NEO4J_PASSWORD= -NEO4J_DB= +OPENAI_API_KEY= POSTGRES_HOST= -POSTGRES_USER= POSTGRES_PASS= POSTGRES_PORT= +POSTGRES_USER= RABBIT_HOST= +RABBIT_PASSWORD= RABBIT_PORT= RABBIT_USER= -RABBIT_PASSWORD= -CHUNK_SIZE= -EMBEDDING_DIM= -K1_RETRIEVER_SEARCH= -K2_RETRIEVER_SEARCH= -D_RETRIEVER_SEARCH= -COHERE_API_KEY= -OPENAI_API_KEY= \ No newline at end of file From 51c4c23636ae85f80b779d5cb4008709ada8c4b8 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 3 Jan 2024 08:41:13 +0330 Subject: [PATCH 21/25] feat: update discord platform description! --- subquery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/subquery.py b/subquery.py index cf4fa10..8f3f3d2 100644 --- a/subquery.py +++ b/subquery.py @@ -68,7 +68,7 @@ def query_multiple_source( ) tool_metadata = ToolMetadata( name="Discord", - description="Provides the discord platform conversations data.", + description="Contains messages and summaries of conversations from the Discord platform of the community", ) tools.append(tool_metadata) From 6f51cd5aad10c91408d7449cea6205c2195b756d Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 4 Jan 2024 16:28:16 +0330 Subject: [PATCH 22/25] feat: Added cohere embedding model and updated subquery - We needed to apply the cohere embedding model in our codes. - The `subquery.py` updated based on little experiments (both embedding model and function output updated). 
--- bot/retrievers/summary_retriever_base.py | 10 +++++++--- subquery.py | 23 ++++++++++++---------- utils/query_engine/discord_query_engine.py | 8 +++++++- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/bot/retrievers/summary_retriever_base.py b/bot/retrievers/summary_retriever_base.py index 1cc3420..0095a7f 100644 --- a/bot/retrievers/summary_retriever_base.py +++ b/bot/retrievers/summary_retriever_base.py @@ -34,7 +34,7 @@ def __init__( the embedding model to use for doing embedding on the query string default would be CohereEmbedding that we've written """ - self.index = self._setup_index(table_name, dbname) + self.index = self._setup_index(table_name, dbname, embedding_model) self.embedding_model = embedding_model def get_similar_nodes( @@ -62,10 +62,14 @@ def get_similar_nodes( return nodes - def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex: + def _setup_index( + self, table_name: str, dbname: str, embedding_model: BaseEmbedding + ) -> VectorStoreIndex: """ setup the llama_index VectorStoreIndex """ - pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname) + pg_vector_access = PGVectorAccess( + table_name=table_name, dbname=dbname, embed_model=embedding_model + ) index = pg_vector_access.load_index() return index diff --git a/subquery.py b/subquery.py index 8f3f3d2..73cbb55 100644 --- a/subquery.py +++ b/subquery.py @@ -1,8 +1,9 @@ from guidance.models import OpenAI as GuidanceOpenAI -from llama_index import QueryBundle +from llama_index import QueryBundle, ServiceContext from llama_index.core import BaseQueryEngine from llama_index.query_engine import SubQuestionQueryEngine from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator +from llama_index.schema import NodeWithScore from llama_index.tools import QueryEngineTool, ToolMetadata from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from utils.query_engine import prepare_discord_engine_auto_filter @@ -17,7 +18,7 
@@ def query_multiple_source( notion: bool, telegram: bool, github: bool, -) -> str: +) -> tuple[str, list[NodeWithScore]]: """ query multiple platforms and get an answer from the multiple @@ -43,9 +44,11 @@ def query_multiple_source( Returns -------- - reponse : str + response : str, the response to the user query from the LLM using the engines of the given platforms (pltform equal to True) + source_nodes : list[NodeWithScore] + the list of nodes that were source of answering """ query_engine_tools: list[QueryEngineTool] = [] tools: list[ToolMetadata] = [] @@ -93,15 +96,15 @@ def query_multiple_source( question_gen = GuidanceQuestionGenerator.from_defaults( guidance_llm=GuidanceOpenAI("text-davinci-003"), verbose=False ) - + embed_model = CohereEmbedding() + service_context = ServiceContext.from_defaults(embed_model=embed_model) s_engine = SubQuestionQueryEngine.from_defaults( question_gen=question_gen, query_engine_tools=query_engine_tools, + use_async=False, + service_context=service_context, ) - reponse = s_engine.query( - QueryBundle( - query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) - ) - ) + query_embedding = embed_model.get_text_embedding(text=query) + response = s_engine.query(QueryBundle(query_str=query, embedding=query_embedding)) - return reponse.response + return response.response, response.source_nodes diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/discord_query_engine.py index 4b121df..ca032c3 100644 --- a/utils/query_engine/discord_query_engine.py +++ b/utils/query_engine/discord_query_engine.py @@ -4,6 +4,7 @@ from llama_index.core import BaseQueryEngine from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding def prepare_discord_engine( @@ -47,7 +48,12 @@ def prepare_discord_engine( testing = kwarg.get("testing", False) - 
pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname, testing=testing) + pg_vector = PGVectorAccess( + table_name=table_name, + dbname=dbname, + testing=testing, + embed_model=CohereEmbedding(), + ) index = pg_vector.load_index() if similarity_top_k is None: _, similarity_top_k, _ = load_hyperparams() From 243de104c62bdc402c9281b43ab11cebf4d224c7 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 4 Jan 2024 16:35:25 +0330 Subject: [PATCH 23/25] fix: isort linter issue based on superlinter rules! --- utils/query_engine/discord_query_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/discord_query_engine.py index ca032c3..6a29833 100644 --- a/utils/query_engine/discord_query_engine.py +++ b/utils/query_engine/discord_query_engine.py @@ -3,8 +3,8 @@ from bot.retrievers.utils.load_hyperparams import load_hyperparams from llama_index.core import BaseQueryEngine from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters -from tc_hivemind_backend.pg_vector_access import PGVectorAccess from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from tc_hivemind_backend.pg_vector_access import PGVectorAccess def prepare_discord_engine( From e469d3d2fd63b10aebc1730379c787aa1f5ced3c Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 4 Jan 2024 16:42:14 +0330 Subject: [PATCH 24/25] update: shared codes lib version! We have added the custom embed model support in its newer version. 
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f41372e..a780b05 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,6 @@ neo4j>=5.14.1, <6.0.0 coverage>=7.3.3, <8.0.0 pytest>=7.4.3, <8.0.0 python-dotenv==1.0.0 -tc-hivemind-backend==1.0.0 +tc-hivemind-backend==1.1.0 celery>=5.3.6, <6.0.0 guidance From 34ceb7ce111cbb4619b7501839b90775f5cd53a0 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 8 Jan 2024 09:07:27 +0330 Subject: [PATCH 25/25] update: llama-index lib usage! we updated the library to the newest right version and we chose the right LLM for the guidance. note: the guidance_llm would create the subqueries. --- requirements.txt | 2 +- subquery.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index a780b05..0a35833 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -llama-index>=0.9.21, <1.0.0 +llama-index>=0.9.26, <1.0.0 pymongo python-dotenv pgvector diff --git a/subquery.py b/subquery.py index 73cbb55..58dfd21 100644 --- a/subquery.py +++ b/subquery.py @@ -1,4 +1,4 @@ -from guidance.models import OpenAI as GuidanceOpenAI +from guidance.models import OpenAIChat from llama_index import QueryBundle, ServiceContext from llama_index.core import BaseQueryEngine from llama_index.query_engine import SubQuestionQueryEngine @@ -94,7 +94,8 @@ def query_multiple_source( raise NotImplementedError question_gen = GuidanceQuestionGenerator.from_defaults( - guidance_llm=GuidanceOpenAI("text-davinci-003"), verbose=False + guidance_llm=OpenAIChat("gpt-3.5-turbo"), + verbose=False, ) embed_model = CohereEmbedding() service_context = ServiceContext.from_defaults(embed_model=embed_model)