Skip to content

Commit

Permalink
Context and Memory to file cache - part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Nov 1, 2024
1 parent 0400bcd commit 2037dc5
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/main/askai/core/commander/commands/history_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def context_forget(context: str | None = None) -> None:
if context := context if context != "ALL" else None:
shared.context.clear(*(re.split(r"[;,|]", context.upper())))
else:
shared.context.forget()
shared.context.forget() # Clear the context
shared.memory.clear() # Also clear the chat memory
text_formatter.commander_print(
f"Context %GREEN%'{context.upper() if context else 'ALL'}'%NC% has been cleared!"
)
f"Context %GREEN%'{context.upper() if context else 'ALL'}'%NC% has been cleared!")

@staticmethod
def context_copy(name: str | None = None) -> None:
Expand Down
12 changes: 3 additions & 9 deletions src/main/askai/core/component/cache_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,9 @@ def read_context(self) -> list[str]:
:return: A list of context entries retrieved from the cache."""
flags: int = re.MULTILINE | re.DOTALL | re.IGNORECASE
context: str = ASKAI_CONTEXT_FILE.read_text()
entries = list(
return list(
filter(str.__len__, map(str.strip, re.split(r"(human|assistant|system):", context, flags=flags))))

return []

def save_memory(self, memory: list[BaseMessage] = None) -> None:
"""Save the context window entries into the context file.
:param memory: A list of memory entries to be saved.
Expand All @@ -230,16 +228,12 @@ def _get_role_(msg: BaseMessage) -> str:
with open(str(ASKAI_MEMORY_FILE), 'w', encoding=Charset.UTF_8.val) as f_hist:
list(map(lambda m: f_hist.write(ensure_endswith(os.linesep, f"{_get_role_(m)}: {m.content}")), memory))

def read_memory(self) -> list[BaseMessage]:
def read_memory(self) -> list[str]:
"""TODO"""

flags: int = re.MULTILINE | re.DOTALL | re.IGNORECASE
memory: str = ASKAI_MEMORY_FILE.read_text()
memories = list(
return list(
filter(str.__len__, map(str.strip, re.split(r"(human|assistant|system):", memory, flags=flags))))

return []



assert (cache := CacheService().INSTANCE) is not None
42 changes: 23 additions & 19 deletions src/main/askai/core/support/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
Copyright (c) 2024, HomeSetup
"""

from askai.core.component.cache_service import cache
from askai.exception.exceptions import TokenLengthExceeded
import os
from collections import defaultdict, deque, namedtuple
from functools import partial, reduce
from typing import Any, AnyStr, get_args, Literal, Optional, TypeAlias

from hspylib.core.preconditions import check_argument
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from typing import Any, AnyStr, get_args, Literal, Optional, TypeAlias

import os
import re
from askai.core.component.cache_service import cache
from askai.exception.exceptions import TokenLengthExceeded

ChatRoles: TypeAlias = Literal["system", "human", "assistant"]

Expand All @@ -46,20 +46,24 @@ def __init__(self, token_limit: int, max_context_size: int):

def __str__(self):
ln: str = os.linesep
return ln.join(f"'{k}': '{v}'" + ln for k, v in self._store.items())
return ln.join(f"'{k}': '{v}'" + ln for k, v in self.store.items())

def __getitem__(self, key) -> deque[ContextEntry]:
return self._store[key]
return self.store[key]

def __iter__(self):
return zip(self._store.keys().__iter__(), self._store.values().__iter__())
return zip(self.store.keys().__iter__(), self.store.values().__iter__())

def __len__(self):
return self._store.__len__()
return self.store.__len__()

@property
def keys(self) -> list[AnyStr]:
return [str(k) for k in self._store.keys()]
return [str(k) for k in self.store.keys()]

@property
def store(self) -> dict[Any, deque]:
return self._store

@property
def max_context_size(self) -> int:
Expand All @@ -79,7 +83,7 @@ def push(self, key: str, content: Any, role: ChatRoles = "human") -> ContextRaw:
check_argument(role in get_args(ChatRoles), f"Invalid ChatRole: '{role}'")
if (token_length := (self.length(key)) + len(content)) > self._token_limit:
raise TokenLengthExceeded(f"Required token length={token_length} limit={self._token_limit}")
if (entry := ContextEntry(role, content.strip())) not in (ctx := self._store[key]):
if (entry := ContextEntry(role, content.strip())) not in (ctx := self.store[key]):
ctx.append(entry)

return self.get(key)
Expand All @@ -90,7 +94,7 @@ def get(self, key: str) -> ContextRaw:
:return: The context message associated with the key.
"""

return [{"role": ctx.role, "content": ctx.content} for ctx in self._store[key]] or []
return [{"role": ctx.role, "content": ctx.content} for ctx in self.store[key]] or []

def set(self, key: str, content: Any, role: ChatRoles = "human") -> ContextRaw:
"""Set the context message in the chat with the specified role.
Expand All @@ -109,7 +113,7 @@ def remove(self, key: str, index: int) -> Optional[str]:
:return: The removed message if successful, otherwise None.
"""
val = None
if ctx := self._store[key]:
if ctx := self.store[key]:
if index < len(ctx):
val = ctx[index]
del ctx[index]
Expand All @@ -120,7 +124,7 @@ def length(self, key: str):
:param key: The identifier for the context.
:return: The length of the context (e.g., number of content entries).
"""
ctx = self._store[key]
ctx = self.store[key]
return reduce(lambda total, e: total + len(e.content), ctx, 0) if len(ctx) > 0 else 0

def join(self, *keys: str) -> LangChainContext:
Expand Down Expand Up @@ -159,10 +163,10 @@ def clear(self, *keys: str) -> int:
"""

count = 0
contexts = list(keys or self._store.keys())
contexts = list(keys or self.store.keys())
while contexts and (key := contexts.pop()):
if key in self._store:
del self._store[key]
if key in self.store:
del self.store[key]
count += 1
return count

Expand All @@ -178,10 +182,10 @@ def size(self, key: str) -> int:
:return: The number of entries in the context.
"""

return len(self._store[key])
return len(self.store[key])

def save(self) -> None:
"""Save the current context window to the cache."""
ctx: LangChainContext = self.join(*self._store.keys())
ctx: LangChainContext = self.join(*self.store.keys())
ctx_str: list[str] = [f"{role}: {msg}" for role, msg in ctx]
cache.save_context(ctx_str)
13 changes: 8 additions & 5 deletions src/main/askai/core/support/shared_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from askai.core.component.recorder import recorder
from askai.core.engine.ai_engine import AIEngine
from askai.core.engine.engine_factory import EngineFactory
from askai.core.support.chat_context import ChatContext
from askai.core.support.chat_context import ChatContext, ContextEntry
from askai.core.support.utilities import display_text

LOGGER_NAME: str = 'Askai-Taius'
Expand Down Expand Up @@ -162,8 +162,10 @@ def create_context(self, token_limit: int) -> ChatContext:
if self._context is None:
self._context = ChatContext(token_limit, configs.max_short_memory_size)
if configs.is_keep_context:
# TODO Add to the context.
ctx = cache.read_context()
entries: list[str] = cache.read_context()
for role, content in zip(entries[::2], entries[1::2]):
ctx = self._context.store["HISTORY"]
ctx.append(ContextEntry(role, content))
return self._context

def create_memory(self, memory_key: str = "chat_history") -> ConversationBufferWindowMemory:
Expand All @@ -175,8 +177,9 @@ def create_memory(self, memory_key: str = "chat_history") -> ConversationBufferW
self._memory = ConversationBufferWindowMemory(
memory_key=memory_key, k=configs.max_short_memory_size, return_messages=True)
if configs.is_keep_context:
# TODO Add to the memory the questions and answers.
mem = cache.read_memory()
entries: list[str] = cache.read_memory()
for role, content in zip(entries[::2], entries[1::2]):
self._memory.chat_memory.add_message(self.context.LANGCHAIN_ROLE_MAP[role](content))
return self._memory

def input_text(self, input_prompt: str, placeholder: str | None = None) -> Optional[str]:
Expand Down

0 comments on commit 2037dc5

Please sign in to comment.