-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from TogetherCrew/feat/subquery-generator
Feat: subquery generator
- Loading branch information
Showing
10 changed files
with
365 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
CHUNK_SIZE= | ||
COHERE_API_KEY= | ||
D_RETRIEVER_SEARCH= | ||
EMBEDDING_DIM= | ||
K1_RETRIEVER_SEARCH= | ||
K2_RETRIEVER_SEARCH= | ||
MONGODB_HOST= | ||
MONGODB_PASS= | ||
MONGODB_PORT= | ||
MONGODB_USER= | ||
NEO4J_DB= | ||
NEO4J_HOST= | ||
NEO4J_PASSWORD= | ||
NEO4J_PORT= | ||
NEO4J_PROTOCOL= | ||
NEO4J_USER= | ||
OPENAI_API_KEY= | ||
POSTGRES_HOST= | ||
POSTGRES_PASS= | ||
POSTGRES_PORT= | ||
POSTGRES_USER= | ||
RABBIT_HOST= | ||
RABBIT_PASSWORD= | ||
RABBIT_PORT= | ||
RABBIT_USER= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,145 +1,36 @@ | ||
from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever | ||
from bot.retrievers.process_dates import process_dates | ||
from bot.retrievers.utils.load_hyperparams import load_hyperparams | ||
from llama_index import QueryBundle | ||
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters | ||
from llama_index.schema import NodeWithScore | ||
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding | ||
from tc_hivemind_backend.pg_vector_access import PGVectorAccess | ||
from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter | ||
|
||
|
||
def query_discord( | ||
community_id: str, | ||
query: str, | ||
thread_names: list[str], | ||
channel_names: list[str], | ||
days: list[str], | ||
similarity_top_k: int | None = None, | ||
) -> str: | ||
) -> tuple[str, list[NodeWithScore]]: | ||
""" | ||
query the discord database using filters given | ||
and give an anwer to the given query using the LLM | ||
query the llm using the query engine | ||
Parameters | ||
------------ | ||
guild_id : str | ||
the discord guild data to query | ||
query_engine : BaseQueryEngine | ||
the prepared query engine | ||
query : str | ||
the query (question) of the user | ||
thread_names : list[str] | ||
the given threads to search for | ||
channel_names : list[str] | ||
the given channels to search for | ||
days : list[str] | ||
the given days to search for | ||
similarity_top_k : int | None | ||
the k similar results to use when querying the data | ||
if `None` will load from `.env` file | ||
the string question | ||
Returns | ||
--------- | ||
---------- | ||
response : str | ||
the LLM response given the query | ||
the LLM response | ||
source_nodes : list[llama_index.schema.NodeWithScore] | ||
the source nodes that helped in answering the question | ||
""" | ||
if similarity_top_k is None: | ||
_, similarity_top_k, _ = load_hyperparams() | ||
|
||
table_name = "discord" | ||
dbname = f"community_{community_id}" | ||
|
||
pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) | ||
|
||
index = pg_vector.load_index() | ||
|
||
thread_filters: list[ExactMatchFilter] = [] | ||
channel_filters: list[ExactMatchFilter] = [] | ||
day_filters: list[ExactMatchFilter] = [] | ||
|
||
for channel in channel_names: | ||
channel_updated = channel.replace("'", "''") | ||
channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated)) | ||
|
||
for thread in thread_names: | ||
thread_updated = thread.replace("'", "''") | ||
thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated)) | ||
|
||
for day in days: | ||
day_filters.append(ExactMatchFilter(key="date", value=day)) | ||
|
||
all_filters: list[ExactMatchFilter] = [] | ||
all_filters.extend(thread_filters) | ||
all_filters.extend(channel_filters) | ||
all_filters.extend(day_filters) | ||
|
||
filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) | ||
|
||
query_engine = index.as_query_engine( | ||
filters=filters, similarity_top_k=similarity_top_k | ||
query_engine = prepare_discord_engine_auto_filter( | ||
community_id=community_id, | ||
query=query, | ||
) | ||
|
||
query_bundle = QueryBundle( | ||
query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query) | ||
) | ||
response = query_engine.query(query_bundle) | ||
|
||
return response.response | ||
|
||
|
||
def query_discord_auto_filter( | ||
community_id: str, | ||
query: str, | ||
similarity_top_k: int | None = None, | ||
d: int | None = None, | ||
) -> str: | ||
""" | ||
get the query results and do the filtering automatically. | ||
By automatically we mean, it would first query the summaries | ||
to get the metadata filters | ||
Parameters | ||
----------- | ||
guild_id : str | ||
the discord guild data to query | ||
query : str | ||
the query (question) of the user | ||
similarity_top_k : int | None | ||
the value for the initial summary search | ||
to get the `k2` count simliar nodes | ||
if `None`, then would read from `.env` | ||
d : int | ||
this would make the secondary search (`query_discord`) | ||
to be done on the `metadata.date - d` to `metadata.date + d` | ||
Returns | ||
--------- | ||
response : str | ||
the LLM response given the query | ||
""" | ||
table_name = "discord_summary" | ||
dbname = f"community_{community_id}" | ||
|
||
if d is None: | ||
_, _, d = load_hyperparams() | ||
if similarity_top_k is None: | ||
similarity_top_k, _, _ = load_hyperparams() | ||
|
||
discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname) | ||
|
||
channels, threads, dates = discord_retriever.retreive_metadata( | ||
query=query, | ||
metadata_group1_key="channel", | ||
metadata_group2_key="thread", | ||
metadata_date_key="date", | ||
similarity_top_k=similarity_top_k, | ||
) | ||
|
||
dates_modified = process_dates(list(dates), d) | ||
|
||
response = query_discord( | ||
community_id=community_id, | ||
query=query, | ||
thread_names=list(threads), | ||
channel_names=list(channels), | ||
days=dates_modified, | ||
) | ||
return response | ||
return response.response, response.source_nodes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from guidance.models import OpenAIChat | ||
from llama_index import QueryBundle, ServiceContext | ||
from llama_index.core import BaseQueryEngine | ||
from llama_index.query_engine import SubQuestionQueryEngine | ||
from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator | ||
from llama_index.schema import NodeWithScore | ||
from llama_index.tools import QueryEngineTool, ToolMetadata | ||
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding | ||
from utils.query_engine import prepare_discord_engine_auto_filter | ||
|
||
|
||
def query_multiple_source( | ||
query: str, | ||
community_id: str, | ||
discord: bool, | ||
discourse: bool, | ||
gdrive: bool, | ||
notion: bool, | ||
telegram: bool, | ||
github: bool, | ||
) -> tuple[str, list[NodeWithScore]]: | ||
""" | ||
query multiple platforms and get an answer from the multiple | ||
Parameters | ||
------------ | ||
query : str | ||
the user question | ||
community_id : str | ||
the community id to get their data | ||
discord : bool | ||
if `True` then add the engine to the subquery_generator | ||
discourse : bool | ||
if `True` then add the engine to the subquery_generator | ||
gdrive : bool | ||
if `True` then add the engine to the subquery_generator | ||
notion : bool | ||
if `True` then add the engine to the subquery_generator | ||
telegram : bool | ||
if `True` then add the engine to the subquery_generator | ||
github : bool | ||
if `True` then add the engine to the subquery_generator | ||
Returns | ||
-------- | ||
response : str, | ||
the response to the user query from the LLM | ||
using the engines of the given platforms (pltform equal to True) | ||
source_nodes : list[NodeWithScore] | ||
the list of nodes that were source of answering | ||
""" | ||
query_engine_tools: list[QueryEngineTool] = [] | ||
tools: list[ToolMetadata] = [] | ||
|
||
discord_query_engine: BaseQueryEngine | ||
# discourse_query_engine: BaseQueryEngine | ||
# gdrive_query_engine: BaseQueryEngine | ||
# notion_query_engine: BaseQueryEngine | ||
# telegram_query_engine: BaseQueryEngine | ||
# github_query_engine: BaseQueryEngine | ||
|
||
# query engine perparation | ||
# tools_metadata and query_engine_tools | ||
if discord: | ||
discord_query_engine = prepare_discord_engine_auto_filter( | ||
community_id, | ||
query, | ||
similarity_top_k=None, | ||
d=None, | ||
) | ||
tool_metadata = ToolMetadata( | ||
name="Discord", | ||
description="Contains messages and summaries of conversations from the Discord platform of the community", | ||
) | ||
|
||
tools.append(tool_metadata) | ||
query_engine_tools.append( | ||
QueryEngineTool( | ||
query_engine=discord_query_engine, | ||
metadata=tool_metadata, | ||
) | ||
) | ||
|
||
if discourse: | ||
raise NotImplementedError | ||
if gdrive: | ||
raise NotImplementedError | ||
if notion: | ||
raise NotImplementedError | ||
if telegram: | ||
raise NotImplementedError | ||
if github: | ||
raise NotImplementedError | ||
|
||
question_gen = GuidanceQuestionGenerator.from_defaults( | ||
guidance_llm=OpenAIChat("gpt-3.5-turbo"), | ||
verbose=False, | ||
) | ||
embed_model = CohereEmbedding() | ||
service_context = ServiceContext.from_defaults(embed_model=embed_model) | ||
s_engine = SubQuestionQueryEngine.from_defaults( | ||
question_gen=question_gen, | ||
query_engine_tools=query_engine_tools, | ||
use_async=False, | ||
service_context=service_context, | ||
) | ||
query_embedding = embed_model.get_text_embedding(text=query) | ||
response = s_engine.query(QueryBundle(query_str=query, embedding=query_embedding)) | ||
|
||
return response.response, response.source_nodes |
Oops, something went wrong.