[Bugfix] add temperature=0 to logprobs and seed args to API models (EleutherAI#2149)

* add temperature for log probs

* add seed

* nit

* add new args to test

* added warning for api chat models
baberabb authored and MFajcik committed Aug 1, 2024
1 parent 85867d8 commit 8a4d119
Showing 3 changed files with 25 additions and 1 deletion.
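
In effect, the loglikelihood request body now pins both decoding and randomness. A minimal sketch (not repository code) of the resulting payload shape; the model name and prompt are illustrative placeholders:

```python
# Sketch of the completions payload on the loglikelihood path after this
# commit; "davinci-002" and the prompt are placeholders, not from the diff.
payload = {
    "model": "davinci-002",
    "prompt": "The quick brown fox",
    "temperature": 0,  # new: greedy decoding so returned logprobs are stable
    "max_tokens": 1,
    "logprobs": 1,
    "seed": 1234,      # new: default seed for reproducibility
    "echo": True,      # return logprobs for the prompt tokens as well
}
```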
lm_eval/models/api_models.py — 3 additions, 0 deletions
@@ -160,6 +160,7 @@ def _create_payload(
         *,
         generate: bool = True,
         gen_kwargs: Optional[dict] = None,
+        seed: int = 1234,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
@@ -334,6 +335,7 @@ def model_call(
                 self.create_message(messages),
                 generate=generate,
                 gen_kwargs=gen_kwargs,
+                seed=self._seed,
                 **kwargs,
             ),
             headers=self.header,
@@ -367,6 +369,7 @@ async def amodel_call(
                 self.create_message(messages),
                 generate=generate,
                 gen_kwargs=gen_kwargs,
+                seed=self._seed,
                 **kwargs,
             )
             cache_method = "generate_until" if generate else "loglikelihood"
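
A runnable sketch of the seed plumbing above: `model_call` forwards an instance-level `self._seed` into `_create_payload`. Here `_seed` is assumed to be populated from a constructor argument, which this diff does not show:

```python
from typing import Optional

class ApiModelSketch:
    """Stand-in for the API model class; only the seed plumbing is modeled."""

    def __init__(self, seed: int = 1234):
        self._seed = seed  # assumption: set from a constructor arg

    def _create_payload(
        self,
        messages,
        *,
        generate: bool = True,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        **kwargs,
    ) -> dict:
        # Subclasses decide where (or whether) ``seed`` lands in the body.
        return {"messages": messages, "seed": seed, **(gen_kwargs or {})}

    def model_call(self, messages, generate=True, gen_kwargs=None, **kwargs):
        # Every request carries the instance-level seed.
        return self._create_payload(
            messages,
            generate=generate,
            gen_kwargs=gen_kwargs,
            seed=self._seed,
            **kwargs,
        )

print(ApiModelSketch(seed=7).model_call(["hi"]))  # {'messages': ['hi'], 'seed': 7}
```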
lm_eval/models/openai_completions.py — 14 additions, 1 deletion
@@ -24,6 +24,7 @@ def _create_payload(
         messages: Union[List[List[int]], List[dict], List[str], str],
         generate=False,
         gen_kwargs: Optional[dict] = None,
+        seed: int = 1234,
         **kwargs,
     ) -> dict:
         if generate:
@@ -37,14 +38,17 @@
                 "max_tokens": max_tokens,
                 "temperature": temperature,
                 "stop": stop,
+                "seed": seed,
                 **gen_kwargs,
             }
         else:
             return {
                 "model": self.model,
                 "prompt": messages,
+                "temperature": 0,
                 "max_tokens": 1,
                 "logprobs": 1,
+                "seed": seed,
                 "echo": True,
             }

@@ -96,6 +100,9 @@ def __init__(
         tokenized_requests=False,
         **kwargs,
     ):
+        eval_logger.warning(
+            "chat-completions endpoint requires the `--apply_chat_template` flag."
+        )
         super().__init__(
             base_url=base_url,
             tokenizer_backend=tokenizer_backend,
@@ -109,7 +116,12 @@ def __init__(
         self._batch_size = 1

     def _create_payload(
-        self, messages: List[Dict], generate=False, gen_kwargs: dict = None, **kwargs
+        self,
+        messages: List[Dict],
+        generate=False,
+        gen_kwargs: dict = None,
+        seed=1234,
+        **kwargs,
     ) -> dict:
         gen_kwargs.pop("do_sample", False)
         max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
@@ -123,6 +135,7 @@ def _create_payload(
             "max_tokens": max_tokens,
             "temperature": temperature,
             "stop": stop[:4],
+            "seed": seed,
             **gen_kwargs,
         }

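One consequence of the chat payload above: `**gen_kwargs` is unpacked after the default keys, so a caller-supplied `seed` or `temperature` still wins. A self-contained sketch under that assumption; the function name and model string are illustrative:

```python
def build_chat_payload(messages, gen_kwargs, model="gpt-4o-mini", seed=1234):
    # Illustrative stand-in for the chat-completions _create_payload above.
    gen_kwargs = dict(gen_kwargs)        # avoid mutating the caller's dict
    gen_kwargs.pop("do_sample", False)   # HF-style flag, not an API field
    max_tokens = gen_kwargs.pop("max_gen_toks", 256)
    temperature = gen_kwargs.pop("temperature", 0)
    stop = gen_kwargs.pop("until", ["<|endoftext|>"])
    return {
        "messages": messages,
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stop": stop[:4],                # the API accepts at most 4 stop strings
        "seed": seed,
        **gen_kwargs,                    # later keys win: user overrides default
    }

payload = build_chat_payload([{"role": "user", "content": "hi"}], {"seed": 42})
print(payload["seed"])  # 42 — the caller's seed overrides the default 1234
```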
tests/models/test_api.py — 8 additions, 0 deletions
@@ -28,6 +28,7 @@ def test_create_payload_generate(api):
         "temperature": 0.7,
         "until": ["The End"],
         "do_sample": True,
+        "seed": 1234,
     }
     payload = api._create_payload(messages, generate=True, gen_kwargs=gen_kwargs)

@@ -37,6 +38,7 @@ def test_create_payload_generate(api):
         "max_tokens": 100,
         "temperature": 0.7,
         "stop": ["The End"],
+        "seed": 1234,
     }


@@ -50,6 +52,8 @@ def test_create_payload_loglikelihood(api):
         "max_tokens": 1,
         "logprobs": 1,
         "echo": True,
+        "temperature": 0,
+        "seed": 1234,
     }


@@ -66,6 +70,7 @@ def test_create_payload_loglikelihood(api):
             "max_tokens": 100,
             "temperature": 0.7,
             "stop": ["<|endoftext|>"],
+            "seed": 1234,
         },
     ),
     (
@@ -78,6 +83,7 @@
             "max_tokens": 256,
             "temperature": 0,
             "stop": ["<|endoftext|>"],
+            "seed": 1234,
         },
     ),
 ],
@@ -116,6 +122,8 @@ def test_model_generate_call_usage(
             "max_tokens": 1,
             "logprobs": 1,
             "echo": True,
+            "seed": 1234,
+            "temperature": 0,
         },
     ),
 ],
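A minimal, self-contained version of what the updated generate-path test asserts; `fake_generate_payload` is a stand-in for `api._create_payload`, not the fixture used in the suite:

```python
def fake_generate_payload(messages, gen_kwargs, seed=1234):
    # Stand-in mirroring the expectations in test_create_payload_generate.
    gen_kwargs = dict(gen_kwargs)
    gen_kwargs.pop("do_sample", None)
    return {
        "prompt": messages,
        "max_tokens": gen_kwargs.pop("max_gen_toks", 100),
        "temperature": gen_kwargs.pop("temperature", 0),
        "stop": gen_kwargs.pop("until", []),
        "seed": seed,
        **gen_kwargs,
    }

def test_generate_payload_includes_seed():
    out = fake_generate_payload(
        "Hello", {"temperature": 0.7, "until": ["The End"], "do_sample": True}
    )
    assert out == {
        "prompt": "Hello",
        "max_tokens": 100,
        "temperature": 0.7,
        "stop": ["The End"],
        "seed": 1234,
    }

test_generate_payload_includes_seed()  # runs without pytest
```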
