From e91d0b48ae582412d59fa9ce380ca0f9bcd6383b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?= <remi@thetypicalset.com>
Date: Fri, 27 Sep 2024 11:36:00 +0200
Subject: [PATCH 1/2] Add special tokens for models

We should download the relevant files from HF. I don't think we can
avoid implementing the Jinja2 templates for each model family though.
Would need to use regular expressions instead of full names (might be slow).
---
 prompts/tokens.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 prompts/tokens.py

diff --git a/prompts/tokens.py b/prompts/tokens.py
new file mode 100644
index 0000000..953ad9f
--- /dev/null
+++ b/prompts/tokens.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+
+@dataclass(frozen=True)
+class Limits:
+    begin: str = ""
+    end: str = ""
+
+
+@dataclass
+class Special:
+    sequence: Limits = Limits("", "")
+    user: Limits = Limits("", "")
+    assistant: Limits = Limits("", "")
+    system: Limits = Limits("", "")
+
+
+SPECIAL_TOKENS: Dict[Optional[str], Special] = {
+    None: Special(),
+    "google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
+    "openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
+    "mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
+    "mistralai/Mistral-7B-Instruct-v0.1": Special(
+        Limits("<s>", "</s>"),
+        Limits("[INST]", "[/INST]"),
+        Limits("", "</s>"),
+    ),
+}

From 65546fe608a1e4ef646d9fec60758fdfe1a5b0b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?= <remi@thetypicalset.com>
Date: Wed, 31 Jul 2024 11:39:16 +0200
Subject: [PATCH 2/2] Track and format chat history

Given the multiplicity of formats, formatting the prompt for chat
workflows with open models can be a real hassle and is error-prone. In
this PR we introduce a `Chat` class that allows users to track the
conversation and easily print the corresponding prompt.
---
 docs/reference/chat.md | 20 ++++++++++
 mkdocs.yml             |  1 +
 prompts/chat.py        | 86 ++++++++++++++++++++++++++++++++++++++++++
 pyproject.toml         |  4 +-
 tests/test_chat.py     |  7 ++++
 5 files changed, 116 insertions(+), 2 deletions(-)
 create mode 100644 docs/reference/chat.md
 create mode 100644 prompts/chat.py
 create mode 100644 tests/test_chat.py

diff --git a/docs/reference/chat.md b/docs/reference/chat.md
new file mode 100644
index 0000000..0b1b193
--- /dev/null
+++ b/docs/reference/chat.md
@@ -0,0 +1,20 @@
+# Chat history
+
+## Filter message
+
+In some situations, you may want to filter the messages before building the prompt, for instance when using RAG. In this case you can subclass `Chat` and override the `filter` method:
+
+
+```python
+from prompts import Chat
+
+class RAGChat(Chat):
+
+    def filter(self):
+        filtered_messages = []
+        for message in self.history:
+            if message.role == "user" and "Hi" in message.content:
+                filtered_messages.append(message)
+
+        return filtered_messages
+```
diff --git a/mkdocs.yml b/mkdocs.yml
index 9d1f2b8..4c788a2 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -76,3 +76,4 @@ nav:
       - Prompt template: reference/template.md
       - Dispatch: reference/dispatch.md
       - Special tokens: reference/special_tokens.md
+      - Chat History: reference/chat.md
diff --git a/prompts/chat.py b/prompts/chat.py
new file mode 100644
index 0000000..e601432
--- /dev/null
+++ b/prompts/chat.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional
+
+from pydantic import BaseModel
+from typing_extensions import TypedDict
+
+
+class Document(TypedDict):
+    title: str
+    text: str
+
+
+class Role(Enum):
+    system = "system"
+    user = "user"
+    assistant = "assistant"
+
+
+@dataclass
+class Message:
+    role: Role
+    content: str
+
+
+class Chat:
+    def __init__(
+        self,
+        system_msg: Optional[str] = None,
+        tools: Optional[List[BaseModel]] = None,
+        documents: Optional[List[Document]] = None,
+        history: Optional[List[Message]] = None,
+    ):
+        self.history = history if history is not None else []
+        self.system = system_msg
+        self.tools = tools
+        self.documents = documents
+
+    @property
+    def trimmed_history(self):
+        return self.history
+
+    def __add__(self, other: Message):
+        history = list(self.history)
+        history.append(other)
+        return Chat(self.system, self.tools, self.documents, history=history)
+
+    def __radd__(self, other: Message):
+        history = list(self.history)
+        history.append(other)
+        return Chat(self.system, self.tools, self.documents, history=history)
+
+    def __iadd__(self, other: Message):
+        self.history.append(other)
+        return self
+
+    def __getitem__(self, key):
+        if isinstance(key, int):
+            return self.history[key]
+        else:
+            raise KeyError()
+
+    def render(self, model_name: str):
+        """Render the conversation using the model's chat template.
+
+        TODO: Do this ourselves.
+
+        Parameters
+        ----------
+        model_name
+            The name of the model whose chat template we need to use.
+
+        """
+        from transformers import AutoTokenizer
+
+        conversation = []
+        if self.system is not None:
+            conversation.append({"role": "system", "content": self.system})
+        for message in self.trimmed_history:
+            conversation.append({"role": message.role, "content": message.content})
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        return self.tokenizer.apply_chat_template(
+            conversation, tools=self.tools, documents=self.documents
+        )
diff --git a/pyproject.toml b/pyproject.toml
index 326a183..b72372c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ version = "0.1.0"
 description = "Large Language Models prompting library"
 authors = [{name = "The Outlines developers", email = "contact@dottxt.co"}]
 requires-python = ">= 3.8"
-dependencies = ["jinja2"]
+dependencies = ["jinja2", "pydantic", "transformers"]
 
 [build-system]
 requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
@@ -35,5 +35,5 @@ file="README.md"
 content-type = "text/markdown"
 
 [[tool.mypy.overrides]]
-module = ["jinja2", "pytest"]
+module = ["jinja2", "pydantic", "pytest", "transformers"]
 ignore_missing_imports = true
diff --git a/tests/test_chat.py b/tests/test_chat.py
new file mode 100644
index 0000000..295dc30
--- /dev/null
+++ b/tests/test_chat.py
@@ -0,0 +1,7 @@
+from prompts.chat import Chat, Message
+
+
+def test_simple():
+    chat = Chat("system message")
+    new_chat = chat + Message("user", "new user message")
+    new_chat += Message("assistant", "new assistant message")