Commit 1f67a50
Multiple output chat and retry function (#171)
* looking good
* adding bypass on benchmark version
* darglint
* wth vscode
* forgot return in retry multiple
* darglint

Co-authored-by: Léo Boisvert <leo.boisvert@hotmail.ca>
1 parent e411f7d commit 1f67a50

File tree: 4 files changed, +75 −6 lines

src/agentlab/experiments/reproducibility_util.py
src/agentlab/experiments/study.py
src/agentlab/llm/chat_api.py
src/agentlab/llm/llm_utils.py


src/agentlab/experiments/reproducibility_util.py

Lines changed: 9 additions & 3 deletions
@@ -19,7 +19,9 @@ def _get_repo(module):
     return Repo(Path(module.__file__).resolve().parent, search_parent_directories=True)
 
 
-def _get_benchmark_version(benchmark: bgym.Benchmark) -> str:
+def _get_benchmark_version(
+    benchmark: bgym.Benchmark, allow_bypass_benchmark_version: bool = False
+) -> str:
     benchmark_name = benchmark.name
 
     if hasattr(benchmark, "get_version"):
@@ -42,7 +44,10 @@ def _get_benchmark_version(benchmark: bgym.Benchmark) -> str:
     elif benchmark_name.startswith("assistantbench"):
         return metadata.distribution("browsergym.assistantbench").version
     else:
-        raise ValueError(f"Unknown benchmark {benchmark_name}")
+        if allow_bypass_benchmark_version:
+            return "bypassed"
+        else:
+            raise ValueError(f"Unknown benchmark {benchmark_name}")
 
 
 def _get_git_username(repo: Repo) -> str:
@@ -183,6 +188,7 @@ def get_reproducibility_info(
         "*inspect_results.ipynb",
     ),
     ignore_changes=False,
+    allow_bypass_benchmark_version=False,
 ):
     """
     Retrieve a dict of information that could influence the reproducibility of an experiment.
@@ -205,7 +211,7 @@
         "benchmark": benchmark.name,
         "study_id": study_id,
         "comment": comment,
-        "benchmark_version": _get_benchmark_version(benchmark),
+        "benchmark_version": _get_benchmark_version(benchmark, allow_bypass_benchmark_version),
         "date": datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
         "os": f"{platform.system()} ({platform.version()})",
         "python_version": platform.python_version(),

src/agentlab/experiments/study.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,3 @@
-from abc import ABC, abstractmethod
 import gzip
 import logging
 import pickle
@@ -269,6 +268,7 @@ def set_reproducibility_info(self, strict_reproducibility=False, comment=None):
             self.uuid,
             ignore_changes=not strict_reproducibility,
             comment=comment,
+            allow_bypass_benchmark_version=not strict_reproducibility,
         )
         if self.reproducibility_info is not None:
             repro.assert_compatible(
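The study now derives both relaxations from the single `strict_reproducibility` switch: with `strict_reproducibility=False`, uncommitted repository changes are tolerated and an unrecognized benchmark has its version recorded as `"bypassed"` rather than aborting. A small sketch of the forwarding only; `reproducibility_kwargs` is a hypothetical helper written for illustration, not part of the codebase:

```python
def reproducibility_kwargs(strict_reproducibility: bool) -> dict:
    """Mirror how set_reproducibility_info forwards the switch (illustration only)."""
    return dict(
        ignore_changes=not strict_reproducibility,
        allow_bypass_benchmark_version=not strict_reproducibility,
    )

print(reproducibility_kwargs(strict_reproducibility=True))   # both False: strict study
print(reproducibility_kwargs(strict_reproducibility=False))  # both True: relaxed study
```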

src/agentlab/llm/chat_api.py

Lines changed: 6 additions & 2 deletions
@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict]) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -275,6 +275,7 @@ def __call__(self, messages: list[dict]) -> dict:
         completion = self.client.chat.completions.create(
             model=self.model_name,
             messages=messages,
+            n=n_samples,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
         )
@@ -305,7 +306,10 @@
         ):
             tracking.TRACKER.instance(input_tokens, output_tokens, cost)
 
-        return AIMessage(completion.choices[0].message.content)
+        if n_samples == 1:
+            return AIMessage(completion.choices[0].message.content)
+        else:
+            return [AIMessage(c.message.content) for c in completion.choices]
 
     def get_stats(self):
         return {
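`__call__` now forwards `n_samples` as the OpenAI `n` parameter, so several completions of the same prompt come back from a single request, and the return type changes accordingly: a single `AIMessage` when `n_samples == 1`, a list of `AIMessage` otherwise. A sketch of what the forwarded parameter does at the API level, using the official `openai` client directly; the model name and prompt are placeholders:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Name one prime number."}],
    n=3,  # three independent samples of the same prompt in one request
)
# One choice per sample; the modified __call__ wraps each choice in an AIMessage.
print([choice.message.content for choice in completion.choices])
```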

src/agentlab/llm/llm_utils.py

Lines changed: 59 additions & 0 deletions
@@ -90,6 +90,65 @@ def retry(
     raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
 
 
+def retry_multiple(
+    chat: "ChatModel",
+    messages: "Discussion",
+    n_retry: int,
+    parser: callable,
+    log: bool = True,
+    num_samples: int = 1,
+):
+    """Retry querying the chat model with the response from the parser until it
+    returns a valid value.
+
+    If none of the sampled answers is valid, it will retry and append the retry
+    message to the chat. It will stop after `n_retry`.
+
+    Note: each retry has to resend the whole prompt to the API. This can be slow
+    and expensive.
+
+    Args:
+        chat (ChatModel): a ChatModel object taking a list of messages and
+            returning a list of answers, all in OpenAI format.
+        messages (list): the list of messages so far. This list will be modified with
+            the new messages and the retry messages.
+        n_retry (int): the maximum number of sequential retries.
+        parser (callable): a function taking a message and returning a parsed value,
+            or raising a ParseError.
+        log (bool): whether to log the retry messages.
+        num_samples (int): the number of samples to generate from the model.
+
+    Returns:
+        tuple[list, int]: the parsed values (each with a string at key "action") and the retry count.
+
+    Raises:
+        ParseError: if the parser could not parse any response after n_retry retries.
+    """
+    tries = 0
+    while tries < n_retry:
+        answer_list = chat(messages, n_samples=num_samples)
+        # TODO: could we change this to not use in-place modifications?
+        messages.append(answer_list[0])
+        parsed_answers = []
+        errors = []
+        for answer in answer_list:
+            try:
+                parsed_answers.append(parser(answer["content"]))
+            except ParseError as parsing_error:
+                errors.append(str(parsing_error))
+        # if we have at least one valid answer, return it
+        if parsed_answers:
+            return parsed_answers, tries
+        else:
+            tries += 1
+            if log:
+                msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(errors)}"
+                logging.info(msg)
+            messages.append(dict(role="user", content=str(errors)))
+
+    raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
+
+
 def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"):
     """Use tiktoken to truncate a text to a maximum number of tokens."""
     enc = tiktoken.encoding_for_model(model_name)
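A usage sketch of the new `retry_multiple` helper. The `StubChat` class and `parse_action` parser below are hypothetical stand-ins so the example runs without an API key; in practice `chat` would be a chat model from `agentlab.llm.chat_api` and `parser` the agent's action parser:

```python
from agentlab.llm.llm_utils import ParseError, retry_multiple

def parse_action(text: str) -> dict:
    """Toy parser: only accept answers that contain the word 'click'."""
    if "click" not in text:
        raise ParseError(f"expected an action containing 'click', got: {text!r}")
    return {"action": text.strip()}

class StubChat:
    """Hypothetical stand-in for a ChatModel: returns n_samples canned answers."""
    def __call__(self, messages, n_samples=1):
        return [dict(role="assistant", content=f"click(bid={i})") for i in range(n_samples)]

messages = [dict(role="user", content="What should I do next?")]
parsed, n_tries = retry_multiple(
    StubChat(), messages, n_retry=3, parser=parse_action, num_samples=4
)
print(parsed)   # four parsed actions, e.g. [{'action': 'click(bid=0)'}, ...]
print(n_tries)  # 0: the first batch already contained valid answers
```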
