Skip to content

Commit 038a70a

Browse files
committed
Automatically render special EOS and BOS tokens
We define Jinja2 `eos` and `bos` global variables that are rendered as EOS and BOS tokens.
1 parent 4be9b25 commit 038a70a

File tree

4 files changed

+93
-3
lines changed

4 files changed

+93
-3
lines changed

docs/reference/special_tokens.md

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Handle special tokens
2+
3+
Tokens that indicate the beginning of a sequence, an end of sequence, that
4+
delineate user and assistant turns in a conversation, etc. are model-specific.
5+
This means that one needs to write a new prompt each time they use a new model,
6+
only replacing these special tokens. This is error-prone and leads to duplicated
7+
work.
8+
9+
`prompts` provides special variables in its templates that allow users to use special tokens in their prompts in a model-agnostic way:
10+
11+
```python
12+
import prompts
13+
14+
15+
@prompts.template
16+
def a_simple_prompt(query: str):
17+
"""{{ bos + query + eos }}"""
18+
19+
20+
print(a_simple_prompt["mistralai/Mistral-7B-v0.1"]("question"))
21+
# <s>question</s>
22+
23+
print(a_simple_prompt["google/gemma-2-9b"]("question"))
24+
# <bos>question<eos>
25+
```
26+
27+
28+
!!! note "Registry"
29+
30+
The registry is currently limited to a few models. Please [open an issue](https://github.com/outlines-dev/prompts/issues) if you
31+
want to use `prompts` with a model that is not currently in the registry.

mkdocs.yml

+1
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ nav:
7575
- reference/index.md
7676
- Prompt template: reference/template.md
7777
- Dispatch: reference/dispatch.md
78+
- Special tokens: reference/special_tokens.md

prompts/templates.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import inspect
22
import re
3+
import warnings
34
from dataclasses import dataclass, field
45
from functools import lru_cache
5-
from typing import Callable, Dict, Hashable, Optional, cast
6+
from typing import Callable, Dict, Hashable, Optional, Tuple, cast
67

78
from jinja2 import Environment, StrictUndefined
89

@@ -29,6 +30,8 @@ class Template:
2930
The template to render.
3031
signature
3132
The prompt function's signature.
33+
model
34+
The model the `Template` is associated with. Defaults to `None`.
3235
registry
3336
Registry that maps function names to their respective `Template`
3437
instances.
@@ -50,7 +53,7 @@ def __call__(self, *args, **kwargs) -> str:
5053
"""
5154
bound_arguments = self.signature.bind(*args, **kwargs)
5255
bound_arguments.apply_defaults()
53-
return render(self.template, **bound_arguments.arguments)
56+
return render(self.template, self.model, **bound_arguments.arguments)
5457

5558
def __str__(self):
5659
return self.template
@@ -74,6 +77,7 @@ def __getitem__(self, model_name: str):
7477
try:
7578
return self.registry[model_name]
7679
except KeyError:
80+
self.model = model_name
7781
return self
7882

7983
def register(self, model_name: str):
@@ -140,13 +144,21 @@ def template(fn: Callable) -> Template:
140144

141145

142146
@lru_cache
143-
def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
147+
def render(
148+
template: str,
149+
model_name: Optional[str] = None,
150+
**values: Optional[Dict[str, Hashable]],
151+
) -> str:
144152
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
145153
146154
This function removes extra whitespaces and linebreaks from templates to
147155
allow users to enter prompts more naturally than if they used Python's
148156
constructs directly. See the examples for a detailed explanation.
149157
158+
We also define the `bos` and `eos` special variables which, when used, will
159+
be replaced by the model's BOS and EOS tokens respectively. This allows you
160+
to write prompts that are model-agnostic.
161+
150162
Examples
151163
--------
152164
@@ -223,6 +235,8 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
223235
----------
224236
template
225237
A string that contains a template written with the Jinja2 syntax.
238+
model_name
239+
The name of the model to which the rendered string will be passed.
226240
**values
227241
Map from the variables in the template to their value.
228242
@@ -245,12 +259,34 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
245259
# used to continue to the next line without linebreak.
246260
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)
247261

262+
# Warn the user when the model is not present in the special token registry.
# Adjacent string literals are concatenated verbatim, so each fragment ends
# with an explicit space; without it the sentences (and the issue-tracker URL)
# run into each other in the emitted message.
if model_name not in SPECIAL_TOKENS:
    warnings.warn(
        UserWarning(
            f"The model {model_name} is not present in the special token registry. "
            "As a result, EOS and BOS tokens will be rendered as the empty string. "
            "Please open an issue: https://github.com/outlines-dev/prompts/issues "
            "and ask for the model to be added to the registry."
        )
    )
272+
248273
env = Environment(
249274
trim_blocks=True,
250275
lstrip_blocks=True,
251276
keep_trailing_newline=True,
252277
undefined=StrictUndefined,
253278
)
279+
env.globals["bos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[0]
280+
env.globals["eos"] = SPECIAL_TOKENS.get(model_name, ("", ""))[1]
254281
jinja_template = env.from_string(cleaned_template)
255282

256283
return jinja_template.render(**values)
284+
285+
286+
# Maps a model name to its (BOS, EOS) token pair; the `None` entry covers templates rendered without a model.
287+
SPECIAL_TOKENS: Dict[Optional[str], Tuple[str, str]] = {
288+
None: ("", ""),
289+
"google/gemma-2-9b": ("<bos>", "<eos>"),
290+
"openai-community/gpt2": ("", "<|endoftext|>"),
291+
"mistralai/Mistral-7B-v0.1": ("<s>", "</s>"),
292+
}

tests/test_templates.py

+22
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def test_only_code(variable):
183183
return variable
184184

185185

186+
@pytest.mark.filterwarnings("ignore: The model")
186187
def test_dispatch():
187188

188189
@prompts.template
@@ -207,3 +208,24 @@ def simple_prompt_name(query: str):
207208
assert simple_prompt("test") == "test"
208209
assert simple_prompt["gpt2"]("test") == "test"
209210
assert simple_prompt["provider/name"]("test") == "name: test"
211+
212+
213+
def test_special_tokens():
    """Check that `bos` and `eos` render as the selected model's tokens."""

    @prompts.template
    def simple_prompt(query: str):
        """{{ bos + query + eos }}"""

    # Without a model, both special tokens render as the empty string.
    rendered_without_model = simple_prompt("test")
    assert rendered_without_model == "test"

    # GPT-2 has an empty BOS entry in the registry, so only EOS is appended.
    rendered_gpt2 = simple_prompt["openai-community/gpt2"]("test")
    assert rendered_gpt2 == "test<|endoftext|>"

    # Mistral wraps the prompt with both tokens.
    rendered_mistral = simple_prompt["mistralai/Mistral-7B-v0.1"]("test")
    assert rendered_mistral == "<s>test</s>"
222+
223+
224+
def test_warn():
    """Rendering for a model missing from the registry emits a `UserWarning`."""

    @prompts.template
    def simple_prompt():
        """test"""

    unregistered_model = "non-existent-model"
    expected_message = "not present in the special token"

    with pytest.warns(UserWarning, match=expected_message):
        simple_prompt[unregistered_model]()

0 commit comments

Comments
 (0)