Commit e5d5171

tweaks for per-call temperature setting

1 parent 0d7df43

2 files changed: +9 -6 lines

src/agentlab/llm/chat_api.py

Lines changed: 5 additions & 5 deletions
@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )
 
@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
 
         if token is None:
             token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
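
A minimal usage sketch of the new override (hedged: the OpenAIChatModel class name, constructor arguments, and import path below are assumptions for illustration, not taken from this commit):

# Hypothetical usage; class name and constructor signature are assumed.
from agentlab.llm.chat_api import OpenAIChatModel

chat = OpenAIChatModel(model_name="gpt-4o-mini", temperature=0.1)
messages = [{"role": "user", "content": "Summarize this diff."}]

answer = chat(messages)                   # falls back to self.temperature (0.1)
sample = chat(messages, temperature=0.9)  # one-off override; the instance default is untouched

Omitting temperature (or passing None) preserves the old behavior, so existing call sites keep working unchanged.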

src/agentlab/llm/huggingface_utils.py

Lines changed: 4 additions & 1 deletion
@@ -56,13 +56,15 @@ def __call__(
         self,
         messages: list[dict],
         n_samples: int = 1,
+        temperature: float = None,
     ) -> Union[AIMessage, List[AIMessage]]:
         """
         Generate one or more responses for the given messages.
 
         Args:
             messages: List of message dictionaries containing the conversation history.
             n_samples: Number of independent responses to generate. Defaults to 1.
+            temperature: The temperature for response sampling. Defaults to None.
 
         Returns:
             If n_samples=1, returns a single AIMessage.
@@ -91,7 +93,8 @@ def __call__(
         itr = 0
         while True:
             try:
-                response = AIMessage(self.llm(prompt))
+                temperature = temperature if temperature is not None else self.temperature
+                response = AIMessage(self.llm(prompt, temperature=temperature))
                 responses.append(response)
                 break
             except Exception as e:
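
Both files rely on the same fallback idiom: a per-call temperature wins when provided; otherwise the value stored on the instance at construction time is used. Since temperature is no longer frozen into the partial, it is supplied on every call instead. A self-contained sketch of the pattern (all names here are illustrative stand-ins, not the library's API):

from functools import partial

# Stand-in for InferenceClient.text_generation; the signature is illustrative.
def text_generation(prompt: str, temperature: float = 1.0, max_new_tokens: int = 64) -> str:
    return f"[T={temperature}, max={max_new_tokens}] {prompt}"

llm = partial(text_generation, max_new_tokens=128)  # temperature no longer baked in
default_temperature = 0.7                           # plays the role of self.temperature

def generate(prompt: str, temperature: float = None) -> str:
    # Per-call override wins; otherwise fall back to the stored default.
    temperature = temperature if temperature is not None else default_temperature
    return llm(prompt, temperature=temperature)

print(generate("hello"))                   # [T=0.7, max=128] hello
print(generate("hello", temperature=0.2))  # [T=0.2, max=128] hello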
