
Commit b01a7f8

Add buffer in the maximum number of tokens generated (to fix #353) (#354)
* Add buffer in the maximum number of tokens generated
* Add the token_buffer consistently in all subclasses
1 parent c39b68a commit b01a7f8

File tree

2 files changed (+18 lines, -4 lines)


prompt2model/utils/api_tools.py

Lines changed: 16 additions & 4 deletions
@@ -72,6 +72,7 @@ def generate_one_completion(
        temperature: float = 0,
        presence_penalty: float = 0,
        frequency_penalty: float = 0,
+        token_buffer: int = 300,
    ) -> openai.Completion:
        """Generate a chat completion using an API-based model.

@@ -86,16 +87,21 @@ def generate_one_completion(
            frequency_penalty: Float between -2.0 and 2.0. Positive values penalize new
                tokens based on their existing frequency in the text so far, decreasing
                the model's likelihood of repeating the same line verbatim.
+            token_buffer: Number of tokens below the LLM's limit to generate. In case
+                our tokenizer does not exactly match the LLM API service's perceived
+                number of tokens, this prevents service errors. On the other hand, this
+                may lead to generating fewer tokens in the completion than is actually
+                possible.

        Returns:
            An OpenAI-like response object if there were no errors in generation.
            In case of API-specific error, Exception object is captured and returned.
        """
        num_prompt_tokens = count_tokens_from_string(prompt)
        if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens
+            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
        else:
-            max_tokens = 4 * num_prompt_tokens
+            max_tokens = 3 * num_prompt_tokens

        response = completion(  # completion gets the key from os.getenv
            model=self.model_name,
@@ -116,6 +122,7 @@ async def generate_batch_completion(
        temperature: float = 1,
        responses_per_request: int = 5,
        requests_per_minute: int = 80,
+        token_buffer: int = 300,
    ) -> list[openai.Completion]:
        """Generate a batch responses from OpenAI Chat Completion API.

@@ -126,6 +133,11 @@ async def generate_batch_completion(
            responses_per_request: Number of responses for each request.
                i.e. the parameter n of API call.
            requests_per_minute: Number of requests per minute to allow.
+            token_buffer: Number of tokens below the LLM's limit to generate. In case
+                our tokenizer does not exactly match the LLM API service's perceived
+                number of tokens, this prevents service errors. On the other hand, this
+                may lead to generating fewer tokens in the completion than is actually
+                possible.

        Returns:
            List of generated responses.
@@ -183,9 +195,9 @@ async def _throttled_completion_acreate(

        num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
        if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens
+            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
        else:
-            max_tokens = 4 * num_prompt_tokens
+            max_tokens = 3 * num_prompt_tokens

        async_responses = [
            _throttled_completion_acreate(
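The core of the change is the completion-length budget: instead of allocating every remaining token of the model's context window to the completion, the code now reserves a buffer (300 tokens by default) so that a mismatch between the local tokenizer and the API service's own token count cannot push the request over the limit. The snippet below is a minimal sketch of that arithmetic outside the library, not prompt2model's actual code; it assumes a tiktoken-based counter in place of count_tokens_from_string, and the function name and cl100k_base encoding are illustrative.

import tiktoken

def compute_completion_budget(
    prompt: str,
    model_max_tokens: int | None,
    token_buffer: int = 300,
) -> int:
    """Illustrative budget: context limit minus prompt tokens minus a safety buffer."""
    # Stand-in for count_tokens_from_string; cl100k_base is an assumption here.
    encoding = tiktoken.get_encoding("cl100k_base")
    num_prompt_tokens = len(encoding.encode(prompt))
    if model_max_tokens:
        # Reserve `token_buffer` tokens so a slightly higher count on the
        # service side does not trigger a context-length error.
        return model_max_tokens - num_prompt_tokens - token_buffer
    # No known limit for this model: cap generation relative to the prompt size.
    return 3 * num_prompt_tokens

# Example: a 4096-token model with a 1000-token prompt leaves 4096 - 1000 - 300 = 2796.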

test_helpers/mock_api.py

Lines changed: 2 additions & 0 deletions
@@ -196,6 +196,7 @@ def generate_one_completion(
        temperature: float = 0,
        presence_penalty: float = 0,
        frequency_penalty: float = 0,
+        token_buffer: int = 300,
    ) -> openai.Completion:
        """Return a mocked object and increment the counter."""
        self.generate_one_call_counter += 1
@@ -207,6 +208,7 @@ async def generate_batch_completion(
        temperature: float = 1,
        responses_per_request: int = 5,
        requests_per_minute: int = 80,
+        token_buffer: int = 300,
    ) -> list[openai.Completion]:
        """Return a mocked object and increment the counter."""
        self.generate_batch_call_counter += 1
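The mock in test_helpers/mock_api.py only needs to accept the same keyword argument so that callers forwarding token_buffer keep working against the test double; the value itself is unused. A hypothetical test-side call might look like the following (the MockAPIAgent class name is illustrative, while the counter attribute comes from the diff above):

mock_agent = MockAPIAgent()  # hypothetical name for the mock agent in mock_api.py
mock_agent.generate_one_completion("hello", temperature=0, token_buffer=300)
assert mock_agent.generate_one_call_counter == 1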

0 commit comments
