diff --git a/hyde.py b/hyde.py
index ffb38d5..de0aa50 100644
--- a/hyde.py
+++ b/hyde.py
@@ -1,24 +1,86 @@
-import json
-import langchain
-from cat.mad_hatter.decorators import hook
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
+
 from cat.log import log
+from cat.mad_hatter.decorators import hook
 
-with open("cat/plugins/ccat_hyde/settings.json", "r") as json_file:
-    settings = json.load(json_file)
+# Keys
+HYDE_ANSWER = "hyde_answer"
+AVERAGE_EMBEDDING = "average_embedding"
 
 
 @hook(priority=1)
 def cat_recall_query(user_message, cat):
+    # Acquire settings
+    settings = cat.mad_hatter.get_plugin().load_settings()
+    log.debug(f" --------- ACQUIRE SETTINGS ---------")
+    log.debug(f"settings: {settings}")
+
     # Make a prompt from template
-    hypothesis_prompt = langchain.PromptTemplate(
+    hypothesis_prompt = PromptTemplate(
         input_variables=["input"],
         template=settings["hyde_prompt"]
     )
 
     # Run a LLM chain with the user message as input
-    hypothesis_chain = langchain.chains.LLMChain(prompt=hypothesis_prompt, llm=cat._llm)
+    hypothesis_chain = LLMChain(prompt=hypothesis_prompt, llm=cat._llm)
     answer = hypothesis_chain(user_message)
-    log(answer, "INFO")
-    return answer["text"]
+
+    # Save the HyDE answer in working memory
+    cat.working_memory[HYDE_ANSWER] = answer["text"]
+
+    log.debug("------------- HYDE -------------")
+    log.debug(f"user message: {user_message}")
+    log.debug(f"hyde answer: {answer['text']}")
+
+    return user_message
+
+
+# Calculates the average of the user's message embedding and the HyDE answer embedding
+def _calculate_vector_average(config: dict, cat):
+
+    # If a HyDE answer exists, calculate and store the average embedding
+    if HYDE_ANSWER in cat.working_memory.keys():
+
+        # Get the user message embedding
+        user_embedding = config['embedding']
+
+        # Calculate the HyDE embedding from the HyDE answer
+        hyde_answer = cat.working_memory[HYDE_ANSWER]
+        hyde_embedding = cat.embedder.embed_query(hyde_answer)
+
+        # Calculate the average embedding and store it in working memory
+        average_embedding = [(x + y) / 2 for x, y in zip(user_embedding, hyde_embedding)]
+        cat.working_memory[AVERAGE_EMBEDDING] = average_embedding
+
+        log.debug(f" --------- CALCULATE AVERAGE ---------")
+        log.debug(f"hyde answer: {hyde_answer}")
+        log.debug(f"user_embedding: {user_embedding}")
+        log.debug(f"hyde_embedding: {hyde_embedding}")
+        log.debug(f"average_embedding: {average_embedding}")
+
+        # Delete the HyDE answer from working memory
+        del cat.working_memory[HYDE_ANSWER]
+
+    # If the average embedding exists, use it as the recall embedding
+    if AVERAGE_EMBEDDING in cat.working_memory.keys():
+        average_embedding = cat.working_memory[AVERAGE_EMBEDDING]
+        config['embedding'] = average_embedding
+
+        log.debug(f" --------- SET EMBEDDING ---------")
+        log.debug(f"average_embedding: {average_embedding}")
+
+
+@hook(priority=1)
+def before_cat_recalls_episodic_memories(config: dict, cat):
+    _calculate_vector_average(config, cat)
+
+@hook(priority=1)
+def before_cat_recalls_declarative_memories(config: dict, cat):
+    _calculate_vector_average(config, cat)
+
+@hook(priority=1)
+def before_cat_recalls_procedural_memories(config: dict, cat):
+    _calculate_vector_average(config, cat)
diff --git a/plugin.json b/plugin.json
index a58d1f2..ab60930 100644
--- a/plugin.json
+++ b/plugin.json
@@ -1,8 +1,8 @@
 {
     "name": "Hypothetical Document Embedding",
-    "version": "0.0.4",
+    "version": "0.0.5",
     "description": "Official plugin of the Cheshire Cat to add the Hypothetical Document Embedding (HyDE) technique",
-    "author_name": "Nicola Corbellini",
+    "author_name": "Nicola Corbellini, Salvatore Mirlocca, Massimiliano D'Amico",
     "author_url": "",
     "plugin_url": "https://github.com/Furrmidable-Crew/ccat_hyde",
     "tags": "hyde, llm, cheshire-cat, embedding",
diff --git a/settings.json b/settings.json
new file mode 100644
index 0000000..4a80184
--- /dev/null
+++ b/settings.json
@@ -0,0 +1,3 @@
+{
+    "hyde_prompt": "You will be given a sentence.\n If the sentence is a question, convert it to a plausible answer. If the sentence does not contain a question, \n just repeat the sentence as is without adding anything to it.\n\n Examples:\n - what furniture there is in my room? --> In my room there is a bed, a wardrobe and a desk with my computer\n - where did you go today --> today I was at school\n - I like ice cream --> I like ice cream\n - how old is Jack --> Jack is 20 years old\n\n Sentence:\n - {input} -->"
+}
\ No newline at end of file
diff --git a/settings.py b/settings.py
index 0801cd9..cc73548 100644
--- a/settings.py
+++ b/settings.py
@@ -1,5 +1,5 @@
-from cat.mad_hatter.decorators import plugin
 from pydantic import BaseModel, Field
+from cat.mad_hatter.decorators import plugin
 
 
 class MySettings(BaseModel):
@@ -20,7 +20,6 @@ class MySettings(BaseModel):
         extra={"type": "TextArea"}
     )
 
-
 @plugin
 def settings_schema():
     return MySettings.schema()
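
For reference, the core of the new recall hooks is an element-wise mean of the user-message embedding and the HyDE-answer embedding. A minimal standalone sketch of that step, using made-up three-dimensional vectors in place of real embedder output (real embeddings have hundreds of dimensions):

# Toy stand-ins for config['embedding'] and cat.embedder.embed_query(hyde_answer)
user_embedding = [0.25, 0.75, 0.5]
hyde_embedding = [0.75, 0.25, 1.0]

# Same element-wise averaging as in _calculate_vector_average above
average_embedding = [(x + y) / 2 for x, y in zip(user_embedding, hyde_embedding)]

print(average_embedding)  # [0.5, 0.5, 0.75]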