Skip to content

Commit 1b914c1

Browse files
committed
Make funtion templates real functions
Prompts remplates are currently contained in the docstring of decorated functions. The main issue with this is that prompt templates cannot be composed. In this commit we instead require users to return the prompt template from the function. The template will then automatically be rendered using the values passed to the function. This is very flexible: some variables can be used inside the functions and not be present in the Jinja2 template that is returned, for instance: ```python import prompts @prompts.template def my_template(a, b): prompt = f'This is a first variable {a}' return prompt + "and a second {{b}}" ```
1 parent 05c9d5e commit 1b914c1

File tree

4 files changed

+36
-56
lines changed

4 files changed

+36
-56
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ from prompts import template
2424

2525
@template
2626
def few_shots(instructions, examples, question):
27-
"""{{ instructions }}
27+
return """{{ instructions }}
2828
2929
Examples
3030
--------

docs/reference/template.md

+10-10
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ will pass to the prompt function.
3838

3939
@prompts.template
4040
def greetings(name, question):
41-
"""Hello, {{ name }}!
41+
return """Hello, {{ name }}!
4242
{{ question }}
4343
"""
4444

@@ -62,7 +62,7 @@ If a variable is missing in the function's arguments, Jinja2 will throw an `Unde
6262

6363
@prompts.template
6464
def greetings(name):
65-
"""Hello, {{ surname }}!"""
65+
return """Hello, {{ surname }}!"""
6666

6767
prompt = greetings("user")
6868
```
@@ -94,7 +94,7 @@ Prompt functions are functions, and thus can be imported from other modules:
9494

9595
@prompts.template
9696
def greetings(name, question):
97-
"""Hello, {{ name }}!
97+
return """Hello, {{ name }}!
9898
{{ question }}
9999
"""
100100
```
@@ -128,7 +128,7 @@ keys `question` and `answer` to the prompt function:
128128

129129
@prompts.template
130130
def few_shots(instructions, examples, question):
131-
"""{{ instructions }}
131+
return """{{ instructions }}
132132

133133
Examples
134134
--------
@@ -207,12 +207,12 @@ below does not matter for formatting:
207207

208208
@prompts.template
209209
def prompt1():
210-
"""My prompt
210+
return """My prompt
211211
"""
212212

213213
@prompts.template
214214
def prompt2():
215-
"""
215+
return """
216216
My prompt
217217
"""
218218

@@ -236,20 +236,20 @@ Indentation is relative to the second line of the docstring, and leading spaces
236236

237237
@prompts.template
238238
def example1():
239-
"""First line
239+
return """First line
240240
Second line
241241
"""
242242

243243
@prompts.template
244244
def example2():
245-
"""
245+
return """
246246
Second line
247247
Third line
248248
"""
249249

250250
@prompts.template
251251
def example3():
252-
"""
252+
return """
253253
Second line
254254
Third line
255255
"""
@@ -285,7 +285,7 @@ You can use the backslash `\` to break a long line of text. It will render as a
285285

286286
@prompts.template
287287
def example():
288-
"""
288+
return """
289289
Break in \
290290
several lines \
291291
But respect the indentation

prompts/templates.py

+18-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from dataclasses import dataclass, field
55
from functools import lru_cache
6-
from typing import Callable, Dict, Hashable, Optional, cast
6+
from typing import Callable, Dict, Hashable, Optional
77

88
from jinja2 import Environment, StrictUndefined
99

@@ -15,7 +15,7 @@ class Template:
1515
"""Represents a prompt template.
1616
1717
A prompt template is a callable that, given a Jinja2 template and a set of values,
18-
renders the template using those values. It is recommended to instantiate `Temaplate`
18+
renders the template using those values. It is recommended to instantiate `Template`
1919
using the `template` decorator, which extracts the template from the function's
2020
docstring and its variables from the function's signature.
2121
@@ -40,11 +40,15 @@ class Template:
4040
4141
"""
4242

43-
template: str
4443
signature: inspect.Signature
44+
fn: Callable
4545
model: Optional[str] = None
4646
registry: Dict[str, Callable] = field(default_factory=dict)
4747

48+
def __init__(self, fn: Callable):
49+
self.fn = fn
50+
self.signature = inspect.signature(fn)
51+
4852
def __call__(self, *args, **kwargs) -> str:
4953
"""Render and return the template.
5054
@@ -55,7 +59,10 @@ def __call__(self, *args, **kwargs) -> str:
5559
"""
5660
bound_arguments = self.signature.bind(*args, **kwargs)
5761
bound_arguments.apply_defaults()
58-
return render(self.template, self.model, **bound_arguments.arguments)
62+
63+
template = self.fn(**bound_arguments.arguments)
64+
65+
return render(template, self.model, **bound_arguments.arguments)
5966

6067
def __str__(self):
6168
return self.template
@@ -104,24 +111,23 @@ def template(fn: Callable) -> Template:
104111
manipulation by providing some degree of encapsulation. It uses the `render`
105112
function internally to render templates.
106113
107-
>>> import outlines
114+
>>> import prompts
108115
>>>
109-
>>> @outlines.prompt
116+
>>> @prompts.template
110117
>>> def build_prompt(question):
111-
... "I have a ${question}"
118+
... return "I have a {{question}}"
112119
...
113120
>>> prompt = build_prompt("How are you?")
114121
115122
This API can also be helpful in an "agent" context where parts of the prompt
116123
are set when the agent is initialized and never modified later. In this situation
117124
we can partially apply the prompt function at initialization.
118125
119-
>>> import outlines
120-
>>> import functools as ft
126+
>>> import prompts
121127
...
122-
>>> @outlines.prompt
128+
>>> @prompts.template
123129
... def solve_task(name: str, objective: str, task: str):
124-
... '''Your name is {{name}}.
130+
... return '''Your name is {{name}}.
125131
.. Your overall objective is to {{objective}}.
126132
... Please solve the following task: {{task}}'''
127133
...
@@ -132,17 +138,7 @@ def template(fn: Callable) -> Template:
132138
A `Prompt` callable class which will render the template when called.
133139
134140
"""
135-
signature = inspect.signature(fn)
136-
137-
# The docstring contains the template that will be rendered to be used
138-
# as a prompt to the language model.
139-
docstring = fn.__doc__
140-
if docstring is None:
141-
raise TypeError("Could not find a template in the function's docstring.")
142-
143-
template = cast(str, docstring)
144-
145-
return Template(template, signature)
141+
return Template(fn)
146142

147143

148144
@lru_cache

tests/test_templates.py

+7-23
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,8 @@ def test_render_jinja():
129129
def test_prompt_basic():
130130
@prompts.template
131131
def test_tpl(variable):
132-
"""{{variable}} test"""
132+
return """{{variable}} test"""
133133

134-
assert test_tpl.template == "{{variable}} test"
135134
assert list(test_tpl.signature.parameters.keys()) == ["variable"]
136135

137136
with pytest.raises(TypeError):
@@ -145,7 +144,7 @@ def test_tpl(variable):
145144

146145
@prompts.template
147146
def test_single_quote_tpl(variable):
148-
"${variable} test"
147+
return "{{variable}} test"
149148

150149
p = test_tpl("test")
151150
assert p == "test test"
@@ -154,9 +153,8 @@ def test_single_quote_tpl(variable):
154153
def test_prompt_kwargs():
155154
@prompts.template
156155
def test_kwarg_tpl(var, other_var="other"):
157-
"""{{var}} and {{other_var}}"""
156+
return """{{var}} and {{other_var}}"""
158157

159-
assert test_kwarg_tpl.template == "{{var}} and {{other_var}}"
160158
assert list(test_kwarg_tpl.signature.parameters.keys()) == ["var", "other_var"]
161159

162160
p = test_kwarg_tpl("test")
@@ -169,30 +167,16 @@ def test_kwarg_tpl(var, other_var="other"):
169167
assert p == "test and test"
170168

171169

172-
def test_no_prompt():
173-
with pytest.raises(TypeError, match="template"):
174-
175-
@prompts.template
176-
def test_empty(variable):
177-
pass
178-
179-
with pytest.raises(TypeError, match="template"):
180-
181-
@prompts.template
182-
def test_only_code(variable):
183-
return variable
184-
185-
186170
@pytest.mark.filterwarnings("ignore: The model")
187171
def test_dispatch():
188172

189173
@prompts.template
190174
def simple_prompt(query: str):
191-
"""{{ query }}"""
175+
return """{{ query }}"""
192176

193177
@simple_prompt.register("provider/name")
194178
def simple_prompt_name(query: str):
195-
"""name: {{ query }}"""
179+
return """name: {{ query }}"""
196180

197181
assert list(simple_prompt.registry.keys()) == ["provider/name"]
198182
assert callable(simple_prompt)
@@ -214,7 +198,7 @@ def test_special_tokens():
214198

215199
@prompts.template
216200
def simple_prompt(query: str):
217-
"""{{ bos + query + eos }}"""
201+
return """{{ bos + query + eos }}"""
218202

219203
assert simple_prompt("test") == "test"
220204
assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>"
@@ -225,7 +209,7 @@ def test_warn():
225209

226210
@prompts.template
227211
def simple_prompt():
228-
"""test"""
212+
return """test"""
229213

230214
with pytest.warns(UserWarning, match="not present in the special token"):
231215
simple_prompt["non-existent-model"]()

0 commit comments

Comments
 (0)