1
1
import inspect
2
2
import re
3
+ import warnings
3
4
from dataclasses import dataclass , field
4
5
from functools import lru_cache
5
- from typing import Callable , Dict , Hashable , Optional , cast
6
+ from typing import Callable , Dict , Hashable , Optional , Tuple , cast
6
7
7
8
from jinja2 import Environment , StrictUndefined
8
9
@@ -29,6 +30,8 @@ class Template:
29
30
The template to render.
30
31
signature
31
32
The prompt function's signature.
33
+ model
34
+ The model the `Template` is associated with. Defaults to `None`.
32
35
registry
33
36
Registry that maps function names to their respective `Template`
34
37
instances.
@@ -50,7 +53,7 @@ def __call__(self, *args, **kwargs) -> str:
50
53
"""
51
54
bound_arguments = self .signature .bind (* args , ** kwargs )
52
55
bound_arguments .apply_defaults ()
53
- return render (self .template , ** bound_arguments .arguments )
56
+ return render (self .template , self . model , ** bound_arguments .arguments )
54
57
55
58
def __str__ (self ):
56
59
return self .template
@@ -74,6 +77,7 @@ def __getitem__(self, model_name: str):
74
77
try :
75
78
return self .registry [model_name ]
76
79
except KeyError :
80
+ self .model = model_name
77
81
return self
78
82
79
83
def register (self , model_name : str ):
@@ -140,13 +144,21 @@ def template(fn: Callable) -> Template:
140
144
141
145
142
146
@lru_cache
143
- def render (template : str , ** values : Optional [Dict [str , Hashable ]]) -> str :
147
+ def render (
148
+ template : str ,
149
+ model_name : Optional [str ] = None ,
150
+ ** values : Optional [Dict [str , Hashable ]],
151
+ ) -> str :
144
152
r"""Parse a Jinja2 template and translate it into an Outlines graph.
145
153
146
154
This function removes extra whitespaces and linebreaks from templates to
147
155
allow users to enter prompts more naturally than if they used Python's
148
156
constructs directly. See the examples for a detailed explanation.
149
157
158
+ We also define the `bos` and `eos` special variables which, when used, will
159
+ be replaced by the model's BOS and EOS tokens respectively. This allows you
160
+ to write prompts that are model-agnostic.
161
+
150
162
Examples
151
163
--------
152
164
@@ -223,6 +235,8 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
223
235
----------
224
236
template
225
237
A string that contains a template written with the Jinja2 syntax.
238
+ model_name
239
+ The name of the model to which the rendered string will be passed.
226
240
**values
227
241
Map from the variables in the template to their value.
228
242
@@ -245,12 +259,34 @@ def render(template: str, **values: Optional[Dict[str, Hashable]]) -> str:
245
259
# used to continue to the next line without linebreak.
246
260
cleaned_template = re .sub (r"(?![\r\n])(\b\s+)" , " " , cleaned_template )
247
261
262
+ # Warn the user when the model is not present in the special token registry
263
+ if model_name not in SPECIAL_TOKENS :
264
+ warnings .warn (
265
+ UserWarning (
266
+ f"The model { model_name } is not present in the special token registry."
267
+ "As a result, EOS and BOS tokens will be rendered as the empty string."
268
+ "Please open an issue: https://github.com/outlines-dev/prompts/issues"
269
+ "And ask for the model to be added to the registry."
270
+ )
271
+ )
272
+
248
273
env = Environment (
249
274
trim_blocks = True ,
250
275
lstrip_blocks = True ,
251
276
keep_trailing_newline = True ,
252
277
undefined = StrictUndefined ,
253
278
)
279
+ env .globals ["bos" ] = SPECIAL_TOKENS .get (model_name , ("" , "" ))[0 ]
280
+ env .globals ["eos" ] = SPECIAL_TOKENS .get (model_name , ("" , "" ))[1 ]
254
281
jinja_template = env .from_string (cleaned_template )
255
282
256
283
return jinja_template .render (** values )
284
+
285
+
286
+ # (BOS, EOS)
287
+ SPECIAL_TOKENS : Dict [Optional [str ], Tuple [str , str ]] = {
288
+ None : ("" , "" ),
289
+ "google/gemma-2-9b" : ("<bos>" , "<eos>" ),
290
+ "openai-community/gpt2" : ("" , "<|endoftext|>" ),
291
+ "mistralai/Mistral-7B-v0.1" : ("<s>" , "</s>" ),
292
+ }
0 commit comments