Skip to content

Commit

Permalink
Merge branch 'GrandeCapo93-develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
nicola-corbellini committed Dec 7, 2023
2 parents e67de94 + fe2a03a commit b38a8e9
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 13 deletions.
80 changes: 71 additions & 9 deletions hyde.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,86 @@
import json
import langchain
from cat.mad_hatter.decorators import hook
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from cat.log import log
from cat.mad_hatter.decorators import hook


with open("cat/plugins/ccat_hyde/settings.json", "r") as json_file:
settings = json.load(json_file)
# Keys
HYDE_ANSWER = "hyde_answer"
AVERAGE_EMBEDDING = "average_embedding"


@hook(priority=1)
def cat_recall_query(user_message, cat):

# Acquire settings
settings = cat.mad_hatter.get_plugin().load_settings()
log.debug(f" --------- ACQUIRE SETTINGS ---------")
log.debug(f"settings: {settings}")

# Make a prompt from template
hypothesis_prompt = langchain.PromptTemplate(
hypothesis_prompt = PromptTemplate(
input_variables=["input"],
template=settings["hyde_prompt"]
)

# Run a LLM chain with the user message as input
hypothesis_chain = langchain.chains.LLMChain(prompt=hypothesis_prompt, llm=cat._llm)
hypothesis_chain = LLMChain(prompt=hypothesis_prompt, llm=cat._llm)
answer = hypothesis_chain(user_message)
log(answer, "INFO")
return answer["text"]

# Save HyDE answer in working memory
cat.working_memory[HYDE_ANSWER] = answer["text"]

log.debug("------------- HYDE -------------")
log.debug(f"user message: {user_message}")
log.debug(f"hyde answer: {answer['text']}")

return user_message


# Calculates the average between the user's message embedding and the Hyde response embedding
def _calculate_vector_average(config: dict, cat):

# If hyde answer exists, calculate and set average embedding
if HYDE_ANSWER in cat.working_memory.keys():

# Get user message embedding
user_embedding = config['embedding']

# Calculate hyde embedding from hyde answer
hyde_answer = cat.working_memory[HYDE_ANSWER]
hyde_embedding = cat.embedder.embed_query(hyde_answer)

# Calculate average embedding and stores it into a working memory
average_embedding = [(x + y)/2 for x, y in zip(user_embedding, hyde_embedding)]
cat.working_memory[AVERAGE_EMBEDDING] = average_embedding

log.debug(f" --------- CALCULATE AVERAGE ---------")
log.debug(f"hyde answer: {hyde_answer}")
log.debug(f"user_embedding: {user_embedding}")
log.debug(f"hyde_embedding: {hyde_embedding}")
log.debug(f"average_embedding: {average_embedding}")

# Delete Hyde Answer from working memory
del cat.working_memory[HYDE_ANSWER]

# If average embedding exists, set the embedding
if AVERAGE_EMBEDDING in cat.working_memory.keys():
average_embedding = cat.working_memory[AVERAGE_EMBEDDING]
config['embedding'] = average_embedding

log.debug(f" --------- SET EMBEDDING ---------")
log.debug(f"average_embedding: {average_embedding}")


@hook(priority=1)
def before_cat_recalls_episodic_memories(config: dict, cat):
_calculate_vector_average(config, cat)

@hook(priority=1)
def before_cat_recalls_declarative_memories(config: dict, cat):
_calculate_vector_average(config, cat)

@hook(priority=1)
def before_cat_recalls_procedural_memories(config: dict, cat):
_calculate_vector_average(config, cat)
4 changes: 2 additions & 2 deletions plugin.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"name": "Hypothetical Document Embedding",
"version": "0.0.4",
"version": "0.0.5",
"description": "Official plugin of the Cheshire Cat to add the Hypothetical Document Embedding (HyDE) technique",
"author_name": "Nicola Corbellini",
"author_name": "Nicola Corbellini, Salvatore Mirlocca, Massimiliano D'Amico",
"author_url": "",
"plugin_url": "https://github.com/Furrmidable-Crew/ccat_hyde",
"tags": "hyde, llm, cheshire-cat, embedding",
Expand Down
3 changes: 3 additions & 0 deletions settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"hyde_prompt": "You will be given a sentence.\n If the sentence is a question, convert it to a plausible answer. If the sentence does not contain a question, \n just repeat the sentence as is without adding anything to it.\n\n Examples:\n - what furniture there is in my room? --> In my room there is a bed, a wardrobe and a desk with my computer\n - where did you go today --> today I was at school\n - I like ice cream --> I like ice cream\n - how old is Jack --> Jack is 20 years old\n\n Sentence:\n - {input} -->"
}
3 changes: 1 addition & 2 deletions settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from cat.mad_hatter.decorators import plugin
from pydantic import BaseModel, Field
from cat.mad_hatter.decorators import plugin


class MySettings(BaseModel):
Expand All @@ -20,7 +20,6 @@ class MySettings(BaseModel):
extra={"type": "TextArea"}
)


@plugin
def settings_schema():
return MySettings.schema()

0 comments on commit b38a8e9

Please sign in to comment.