Skip to content

Commit

Permalink
Merge pull request #112 from dot-agent/fixing-memory
Browse files Browse the repository at this point in the history
Fixing memory
  • Loading branch information
anubrag authored Jan 10, 2024
2 parents cd35659 + eb414ba commit f6007bc
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 16 deletions.
15 changes: 7 additions & 8 deletions nextpy/ai/engine/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
log=None,
memory=None,
memory_threshold=1,
memory_llm=None,
**kwargs,
):
"""Create a new Program object from a program string.
Expand Down Expand Up @@ -222,19 +223,19 @@ def __init__(
self.await_missing = await_missing
self.log = log
self.memory = memory
self.memory_threshold = memory_threshold

if self.memory is not None:
if not isinstance(self.memory, BaseMemory):
raise TypeError("Memory type is not compatible")

# If user has not passed separate llm for memory, use engine's
self.memory.llm = memory_llm if memory_llm is not None else llm
self.memory.memory_threshold = memory_threshold

if self._text.find("ConversationHistory") == -1:
self._text = add_variable(self._text)

if self.memory is not None:
ConversationHistory = self.memory.get_memory(
memory_threshold=self.memory_threshold
)
ConversationHistory = self.memory.get_memory()
kwargs["ConversationHistory"] = ConversationHistory
self.ConversationHistory = ConversationHistory

Expand Down Expand Up @@ -382,9 +383,7 @@ def __call__(self, from_agent=False, **kwargs):

if self.memory is not None:
if not from_agent:
self.ConversationHistory = self.memory.get_memory(
memory_threshold=self.memory_threshold
)
self.ConversationHistory = self.memory.get_memory()
kwargs["ConversationHistory"] = self.ConversationHistory

log.debug(f"in __call__ with kwargs: {kwargs}")
Expand Down
2 changes: 2 additions & 0 deletions nextpy/ai/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class BaseMemory(ABC, BaseModel):
This class defines the interface for setting, getting, and checking existence of data in memory.
"""

llm: Any = None
memory_threshold: int = 1
# All memories are stored in this list
messages: List[Dict[BaseMessage, Any]] = Field(default=list())

Expand Down
3 changes: 0 additions & 3 deletions nextpy/ai/memory/read_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ class ReadOnlyMemory(BaseMemory):

memory: BaseMemory

def __init__(self, memory: BaseMemory):
self.memory = memory

def add_memory(self, prompt: str, llm_response: Any) -> None:
"""cannot edit a read only memory."""
pass
Expand Down
9 changes: 4 additions & 5 deletions nextpy/ai/memory/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def add_memory(self, prompt: str, llm_response: Any) -> None:

def get_memory(self, **kwargs) -> str:
"""Retrieve entire memory from the store."""
# Create llm instance

llm = engine.llms.OpenAI(model="gpt-3.5-turbo")

new_messages = [
Expand All @@ -59,10 +59,9 @@ def get_memory(self, **kwargs) -> str:
+ "\n"
)
self.messages_in_summary.append(conversation)

summarizer = engine(template=SUMMARIZER_TEMPLATE, llm=llm, stream=False)
summarizer = engine(template=SUMMARIZER_TEMPLATE, stream=False, llm=self.llm if self.llm is not None else llm)
summarized_memory = summarizer(
summary=self.current_summary, new_lines=messages_to_text
summary=self.current_summary, new_lines=messages_to_text,
)
self.current_summary = extract_text(summarized_memory.text)
summarized_memory = "Current conversation:\n" + self.current_summary
Expand Down Expand Up @@ -96,7 +95,7 @@ def remove_memory(self, prompt: str, llm=Any) -> None:
self.messages_in_summary.append(conversation)

summarizer = engine(
template=SUMMARIZER_TEMPLATE, llm=llm, stream=False
template=SUMMARIZER_TEMPLATE, llm=self.llm, stream=False
)
summarized_memory = summarizer(
summary="", new_lines=messages_to_text
Expand Down

0 comments on commit f6007bc

Please sign in to comment.