From 2037dc54acfe76ccc5606ac587b93ab795e5eb88 Mon Sep 17 00:00:00 2001 From: Hugo Saporetti Junior Date: Fri, 1 Nov 2024 19:18:47 -0300 Subject: [PATCH] Context and Memory to file cache - part - 2 --- .../core/commander/commands/history_cmd.py | 6 +-- .../askai/core/component/cache_service.py | 12 ++---- src/main/askai/core/support/chat_context.py | 42 ++++++++++--------- .../askai/core/support/shared_instances.py | 13 +++--- 4 files changed, 37 insertions(+), 36 deletions(-) diff --git a/src/main/askai/core/commander/commands/history_cmd.py b/src/main/askai/core/commander/commands/history_cmd.py index d8f7e217..6324d0ab 100644 --- a/src/main/askai/core/commander/commands/history_cmd.py +++ b/src/main/askai/core/commander/commands/history_cmd.py @@ -59,10 +59,10 @@ def context_forget(context: str | None = None) -> None: if context := context if context != "ALL" else None: shared.context.clear(*(re.split(r"[;,|]", context.upper()))) else: - shared.context.forget() + shared.context.forget() # Clear the context + shared.memory.clear() # Also clear the chat memory text_formatter.commander_print( - f"Context %GREEN%'{context.upper() if context else 'ALL'}'%NC% has been cleared!" - ) + f"Context %GREEN%'{context.upper() if context else 'ALL'}'%NC% has been cleared!") @staticmethod def context_copy(name: str | None = None) -> None: diff --git a/src/main/askai/core/component/cache_service.py b/src/main/askai/core/component/cache_service.py index 65c925fa..7ee935e0 100644 --- a/src/main/askai/core/component/cache_service.py +++ b/src/main/askai/core/component/cache_service.py @@ -213,11 +213,9 @@ def read_context(self) -> list[str]: :return: A list of context entries retrieved from the cache.""" flags: int = re.MULTILINE | re.DOTALL | re.IGNORECASE context: str = ASKAI_CONTEXT_FILE.read_text() - entries = list( + return list( filter(str.__len__, map(str.strip, re.split(r"(human|assistant|system):", context, flags=flags)))) - return [] - def save_memory(self, memory: list[BaseMessage] = None) -> None: """Save the context window entries into the context file. :param memory: A list of memory entries to be saved. @@ -230,16 +228,12 @@ def _get_role_(msg: BaseMessage) -> str: with open(str(ASKAI_MEMORY_FILE), 'w', encoding=Charset.UTF_8.val) as f_hist: list(map(lambda m: f_hist.write(ensure_endswith(os.linesep, f"{_get_role_(m)}: {m.content}")), memory)) - def read_memory(self) -> list[BaseMessage]: + def read_memory(self) -> list[str]: """TODO""" - flags: int = re.MULTILINE | re.DOTALL | re.IGNORECASE memory: str = ASKAI_MEMORY_FILE.read_text() - memories = list( + return list( filter(str.__len__, map(str.strip, re.split(r"(human|assistant|system):", memory, flags=flags)))) - return [] - - assert (cache := CacheService().INSTANCE) is not None diff --git a/src/main/askai/core/support/chat_context.py b/src/main/askai/core/support/chat_context.py index 5caebb13..9c1ba4d2 100644 --- a/src/main/askai/core/support/chat_context.py +++ b/src/main/askai/core/support/chat_context.py @@ -13,17 +13,17 @@ Copyright (c) 2024, HomeSetup """ -from askai.core.component.cache_service import cache -from askai.exception.exceptions import TokenLengthExceeded +import os from collections import defaultdict, deque, namedtuple from functools import partial, reduce +from typing import Any, AnyStr, get_args, Literal, Optional, TypeAlias + from hspylib.core.preconditions import check_argument from langchain_community.chat_message_histories.in_memory import ChatMessageHistory from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from typing import Any, AnyStr, get_args, Literal, Optional, TypeAlias -import os -import re +from askai.core.component.cache_service import cache +from askai.exception.exceptions import TokenLengthExceeded ChatRoles: TypeAlias = Literal["system", "human", "assistant"] @@ -46,20 +46,24 @@ def __init__(self, token_limit: int, max_context_size: int): def __str__(self): ln: str = os.linesep - return ln.join(f"'{k}': '{v}'" + ln for k, v in self._store.items()) + return ln.join(f"'{k}': '{v}'" + ln for k, v in self.store.items()) def __getitem__(self, key) -> deque[ContextEntry]: - return self._store[key] + return self.store[key] def __iter__(self): - return zip(self._store.keys().__iter__(), self._store.values().__iter__()) + return zip(self.store.keys().__iter__(), self.store.values().__iter__()) def __len__(self): - return self._store.__len__() + return self.store.__len__() @property def keys(self) -> list[AnyStr]: - return [str(k) for k in self._store.keys()] + return [str(k) for k in self.store.keys()] + + @property + def store(self) -> dict[Any, deque]: + return self._store @property def max_context_size(self) -> int: @@ -79,7 +83,7 @@ def push(self, key: str, content: Any, role: ChatRoles = "human") -> ContextRaw: check_argument(role in get_args(ChatRoles), f"Invalid ChatRole: '{role}'") if (token_length := (self.length(key)) + len(content)) > self._token_limit: raise TokenLengthExceeded(f"Required token length={token_length} limit={self._token_limit}") - if (entry := ContextEntry(role, content.strip())) not in (ctx := self._store[key]): + if (entry := ContextEntry(role, content.strip())) not in (ctx := self.store[key]): ctx.append(entry) return self.get(key) @@ -90,7 +94,7 @@ def get(self, key: str) -> ContextRaw: :return: The context message associated with the key. """ - return [{"role": ctx.role, "content": ctx.content} for ctx in self._store[key]] or [] + return [{"role": ctx.role, "content": ctx.content} for ctx in self.store[key]] or [] def set(self, key: str, content: Any, role: ChatRoles = "human") -> ContextRaw: """Set the context message in the chat with the specified role. @@ -109,7 +113,7 @@ def remove(self, key: str, index: int) -> Optional[str]: :return: The removed message if successful, otherwise None. """ val = None - if ctx := self._store[key]: + if ctx := self.store[key]: if index < len(ctx): val = ctx[index] del ctx[index] @@ -120,7 +124,7 @@ def length(self, key: str): :param key: The identifier for the context. :return: The length of the context (e.g., number of content entries). """ - ctx = self._store[key] + ctx = self.store[key] return reduce(lambda total, e: total + len(e.content), ctx, 0) if len(ctx) > 0 else 0 def join(self, *keys: str) -> LangChainContext: @@ -159,10 +163,10 @@ def clear(self, *keys: str) -> int: """ count = 0 - contexts = list(keys or self._store.keys()) + contexts = list(keys or self.store.keys()) while contexts and (key := contexts.pop()): - if key in self._store: - del self._store[key] + if key in self.store: + del self.store[key] count += 1 return count @@ -178,10 +182,10 @@ def size(self, key: str) -> int: :return: The number of entries in the context. """ - return len(self._store[key]) + return len(self.store[key]) def save(self) -> None: """Save the current context window to the cache.""" - ctx: LangChainContext = self.join(*self._store.keys()) + ctx: LangChainContext = self.join(*self.store.keys()) ctx_str: list[str] = [f"{role}: {msg}" for role, msg in ctx] cache.save_context(ctx_str) diff --git a/src/main/askai/core/support/shared_instances.py b/src/main/askai/core/support/shared_instances.py index a124b298..2a404c86 100644 --- a/src/main/askai/core/support/shared_instances.py +++ b/src/main/askai/core/support/shared_instances.py @@ -35,7 +35,7 @@ from askai.core.component.recorder import recorder from askai.core.engine.ai_engine import AIEngine from askai.core.engine.engine_factory import EngineFactory -from askai.core.support.chat_context import ChatContext +from askai.core.support.chat_context import ChatContext, ContextEntry from askai.core.support.utilities import display_text LOGGER_NAME: str = 'Askai-Taius' @@ -162,8 +162,10 @@ def create_context(self, token_limit: int) -> ChatContext: if self._context is None: self._context = ChatContext(token_limit, configs.max_short_memory_size) if configs.is_keep_context: - # TODO Add to the context. - ctx = cache.read_context() + entries: list[str] = cache.read_context() + for role, content in zip(entries[::2], entries[1::2]): + ctx = self._context.store["HISTORY"] + ctx.append(ContextEntry(role, content)) return self._context def create_memory(self, memory_key: str = "chat_history") -> ConversationBufferWindowMemory: @@ -175,8 +177,9 @@ def create_memory(self, memory_key: str = "chat_history") -> ConversationBufferW self._memory = ConversationBufferWindowMemory( memory_key=memory_key, k=configs.max_short_memory_size, return_messages=True) if configs.is_keep_context: - # TODO Add to the memory the questions and answers. - mem = cache.read_memory() + entries: list[str] = cache.read_memory() + for role, content in zip(entries[::2], entries[1::2]): + self._memory.chat_memory.add_message(self.context.LANGCHAIN_ROLE_MAP[role](content)) return self._memory def input_text(self, input_prompt: str, placeholder: str | None = None) -> Optional[str]: