
Commit b65dc33

Track and format chat history
Given the multiplicity of formats, formatting the prompt for chat workflows with open models can be a real hassle and is error-prone. In this PR we introduce a `Chat` class that allows users to track the conversation and easily print the corresponding prompt.
1 parent e91d0b4 commit b65dc33
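
For reference, a minimal usage sketch based on the code added in this commit; the model name is just an example, any model whose tokenizer ships a chat template would do:

```python
from prompts.chat import Chat, Message

# Start a conversation with a system prompt, then append messages.
chat = Chat("You are a helpful assistant.")
chat += Message("user", "How do I format a chat prompt for an open model?")
chat += Message("assistant", "You can use the model's chat template.")

# Render the prompt using the model's chat template (example model name).
prompt = chat.render("HuggingFaceH4/zephyr-7b-beta")
print(prompt)
```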

5 files changed, +111 −2 lines changed

docs/reference/chat.md

+20
# Chat history

## Filter messages

In some situations you may want to filter the messages before building the prompt, for instance to use RAG. In this case you can subclass `Chat` and override the `filter` method:

```python
from prompts import Chat


class RAGChat(Chat):

    def filter(self):
        filtered_messages = []
        for message in self.history:
            if message.role == "user" and "Hi" in message.content:
                filtered_messages.append(message)

        return filtered_messages
```

mkdocs.yml

+1
@@ -76,3 +76,4 @@ nav:
   - Prompt template: reference/template.md
   - Dispatch: reference/dispatch.md
   - Special tokens: reference/special_tokens.md
+  - Chat History: reference/chat.md

prompts/chat.py

+81
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel
from typing_extensions import TypedDict


class Document(TypedDict):
    title: str
    text: str


class Role(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"


@dataclass
class Message:
    role: Role
    content: str


class Chat:

    def __init__(
        self,
        system_msg: Optional[str] = None,
        tools: Optional[List[BaseModel]] = None,
        documents: Optional[List[Document]] = None,
        history: Optional[List[Message]] = None,
    ):
        # Default to `None` to avoid sharing one mutable list between instances.
        self.history = history if history is not None else []
        self.system = system_msg
        self.tools = tools
        self.documents = documents

    @property
    def trimmed_history(self):
        return self.history

    def __add__(self, other: Message):
        # Copy the history so `chat + message` returns a new `Chat`
        # without mutating the current one.
        history = list(self.history)
        history.append(other)
        return Chat(self.system, self.tools, self.documents, history=history)

    def __iadd__(self, other: Message):
        self.history.append(other)
        return self

    def __getitem__(self, key):
        if isinstance(key, int):
            return self.history[key]
        else:
            raise KeyError()

    def render(self, model_name: str):
        """Render the conversation using the model's chat template.

        TODO: Do this ourselves.

        Parameters
        ----------
        model_name
            The name of the model whose chat template we need to use.

        """
        from transformers import AutoTokenizer

        conversation = []
        if self.system is not None:
            conversation.append({"role": "system", "content": self.system})
        for message in self.trimmed_history:
            conversation.append({"role": message.role, "content": message.content})

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # `tokenize=False` returns the formatted prompt string rather than token ids.
        return self.tokenizer.apply_chat_template(
            conversation,
            tools=self.tools,
            documents=self.documents,
            tokenize=False,
        )

pyproject.toml

+2-2
@@ -4,7 +4,7 @@ version = "0.1.0"
 description = "Large Language Models prompting library"
 authors = [{name = "The Outlines developers", email = "contact@dottxt.co"}]
 requires-python = ">= 3.8"
-dependencies = ["jinja2"]
+dependencies = ["jinja2", "pydantic", "transformers"]

 [build-system]
 requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]

@@ -35,5 +35,5 @@ file="README.md"
 content-type = "text/markdown"

 [[tool.mypy.overrides]]
-module = ["jinja2", "pytest"]
+module = ["jinja2", "pydantic", "pytest", "transformers"]
 ignore_missing_imports = true

tests/test_chat.py

+7
from prompts.chat import Chat, Message


def test_simple():
    chat = Chat("system message")
    new_chat = chat + Message("user", "new user message")
    new_chat += Message("assistant", "new assistant message")
    assert len(new_chat.history) == 2
