
Commit b47c9ba

Automatically render user, assistant, system variables
1 parent d2eeaf9 commit b47c9ba

4 files changed (+64 −12 lines)


docs/reference/special_tokens.md (+21)

````diff
@@ -6,6 +6,9 @@ This means that one needs to write a new prompt each time they use a new model,
 only replacing these special tokens. This is error-prone and leads to duplicated
 work.
 
+
+## Beginning and end of sequences
+
 `prompts` provides special variables in its templates that allow users to use special tokens in their prompts in a model-agnostic way:
 
 ```python
@@ -29,3 +32,21 @@ print(a_simple_prompt["google/gemma-2-9b"]("question"))
 
 The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you
 want to use `prompts` with a model that is not currently in the registry.
+
+
+## Chat and Instruct models
+
+`prompts` also provides the special variables `user`, `assistant` and `system`, which are related to chat workflows, so you can design prompts with a chat format in a model-agnostic way:
+
+```python
+import prompts
+
+
+@prompts.template
+def simple_prompt(favorite: str):
+    """{{ bos + user.begin }} What is your favorite {{ favorite + '? ' + user.end }}
+    {{ assistant.begin }}
+    """
+```
+
+Chat templates are so idiosyncratic, however, that we recommend using the `Chat` class to format prompts according to chat templates.
````
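As a rough sketch of how these chat variables combine with the registry added in `prompts/tokens.py` below, rendering the documented template for `mistralai/Mistral-7B-Instruct-v0.1` fills `bos`, `user.begin` and `user.end` with that model's tokens (the indexing-by-model-name call follows the existing `a_simple_prompt["google/gemma-2-9b"]("question")` usage shown above; the exact trailing whitespace depends on how the library dedents docstrings, so treat the printed output as approximate):

```python
import prompts


@prompts.template
def simple_prompt(favorite: str):
    """{{ bos + user.begin }} What is your favorite {{ favorite + '? ' + user.end }}
    {{ assistant.begin }}
    """


# For a registered chat model the variables resolve to its tokens, so this should
# print roughly "<s>[INST] What is your favorite sport? [/INST]"
# (assistant.begin is empty for this model).
print(simple_prompt["mistralai/Mistral-7B-Instruct-v0.1"]("sport"))
```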

prompts/templates.py (+8 −12)

```diff
@@ -3,10 +3,12 @@
 import warnings
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Callable, Dict, Hashable, Optional, Tuple, cast
+from typing import Callable, Dict, Hashable, Optional, cast
 
 from jinja2 import Environment, StrictUndefined
 
+from prompts.tokens import SPECIAL_TOKENS, Special
+
 
 @dataclass
 class Template:
@@ -276,17 +278,11 @@ def render(
         keep_trailing_newline=True,
         undefined=StrictUndefined,
     )
-    env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0]
-    env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1]
+    env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin
+    env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end
+    env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user
+    env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant
+    env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system
    jinja_template = env.from_string(cleaned_template)
 
     return jinja_template.render(**values)
-
-
-# (BOS, EOS)
-SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = {
-    None: ("", ""),
-    "google/gemma-2-9b": ("<bos>", "<eos>"),
-    "openai-community/gpt2": ("", "<|endoftext|>"),
-    "mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
-}
```
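To make the fallback behaviour concrete, here is a small sketch of what the `SPECIAL_TOKENS.get(model_name, Special())` lookups in `render` resolve to (the unregistered model name is made up for illustration):

```python
from prompts.tokens import SPECIAL_TOKENS, Special

# A model present in the registry exposes its chat delimiters.
special = SPECIAL_TOKENS.get("mistralai/Mistral-7B-Instruct-v0.1", Special())
print(special.sequence.begin)  # "<s>"  -> bound to the `bos` template variable
print(special.user.begin)      # "[INST]"
print(special.user.end)        # "[/INST]"

# An unregistered model falls back to Special(), so every token renders as "".
fallback = SPECIAL_TOKENS.get("acme/unknown-model", Special())
print(fallback.assistant.begin)  # ""
print(fallback.sequence.end)     # ""
```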

prompts/tokens.py (new file, +29)

```diff
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+
+@dataclass
+class Limits:
+    begin: str = ""
+    end: str = ""
+
+
+@dataclass
+class Special:
+    sequence: Limits = Limits("", "")
+    user: Limits = Limits("", "")
+    assistant: Limits = Limits("", "")
+    system: Limits = Limits("", "")
+
+
+SPECIAL_TOKENS: Dict[Optional[str], Special] = {
+    None: Special(),
+    "google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
+    "openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
+    "mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
+    "mistralai/Mistral-7B-Instruct-v0.1": Special(
+        Limits("<s>", "</s>"),
+        Limits("[INST]", "[/INST]"),
+        Limits("", "</s>"),
+    ),
+}
```
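The positional order of the `Special` fields is `(sequence, user, assistant, system)`, which is how the `Mistral-7B-Instruct-v0.1` entry above is built (its `system` delimiters are left at the default). A hypothetical entry for another chat model, with placeholder name and tokens that are not part of this commit, would follow the same pattern:

```python
from prompts.tokens import SPECIAL_TOKENS, Limits, Special

# Placeholder registry entry: illustrates the field order and the use of
# keyword arguments instead of relying on the positional order.
SPECIAL_TOKENS["acme/chat-model-v1"] = Special(
    sequence=Limits("<s>", "</s>"),
    user=Limits("<|user|>", "<|end|>"),
    assistant=Limits("<|assistant|>", "<|end|>"),
    system=Limits("<|system|>", "<|end|>"),
)
```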

tests/test_tokens.py (new file, +6)

```diff
@@ -0,0 +1,6 @@
+from prompts.tokens import Special
+
+
+def test_simple():
+    special = Special()
+    assert special.assistant.begin == ""
```
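A follow-up test along the same lines, not part of this commit, could check that a registry entry exposes the delimiters defined in `prompts/tokens.py`:

```python
from prompts.tokens import SPECIAL_TOKENS


def test_registry_entry():
    # Values taken from the Mistral-7B-Instruct-v0.1 entry in SPECIAL_TOKENS.
    special = SPECIAL_TOKENS["mistralai/Mistral-7B-Instruct-v0.1"]
    assert special.sequence.begin == "<s>"
    assert special.user.begin == "[INST]"
    assert special.user.end == "[/INST]"
    assert special.assistant.end == "</s>"
```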
