@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )

-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )

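Taken together, the two hunks above make the sampling temperature overridable per call: __call__ now accepts an optional temperature argument and falls back to the instance-level self.temperature when none is given. A minimal usage sketch of that behavior; the class name OpenAIChatModel, its constructor arguments, and the model name are assumptions for illustration, not identifiers from this commit:

    # Hypothetical usage of the per-call override introduced above.
    # OpenAIChatModel and its constructor are placeholders for whatever
    # wrapper class owns this __call__ in the actual file.
    model = OpenAIChatModel(model_name="gpt-4o-mini", temperature=0.7, max_tokens=256)
    messages = [{"role": "user", "content": "Name one prime number."}]

    out_default = model(messages)                  # uses self.temperature (0.7)
    out_greedy = model(messages, temperature=0.0)  # per-call override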
@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature

         if token is None:
             token = os.environ["TGI_TOKEN"]

         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
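The TGI hunk applies the same idea on the text_generation path: temperature is no longer frozen into the partial at construction time, so it can be supplied per call from the stored self.temperature. A standalone sketch of the resulting pattern, with a placeholder endpoint and temperature values; huggingface_hub's InferenceClient.text_generation does accept a temperature keyword:

    from functools import partial
    from huggingface_hub import InferenceClient

    # Freeze only max_new_tokens into the partial; pass temperature per call.
    # The endpoint URL and both temperature values are placeholders.
    client = InferenceClient(model="http://localhost:8080", token=None)
    llm = partial(client.text_generation, max_new_tokens=128)

    out_default = llm("The sky is", temperature=0.7)   # instance default
    out_greedy = llm("The sky is", temperature=0.01)   # near-greedy override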