Commit

add client session fix
saum7800 committed Apr 18, 2024
1 parent 04debc9 commit 89fcc46
Showing 1 changed file with 69 additions and 68 deletions.
137 changes: 69 additions & 68 deletions prompt2model/utils/api_tools.py
@@ -141,80 +141,81 @@ async def generate_batch_completion(
         Returns:
             List of generated responses.
         """
-        openai.aiosession.set(ClientSession())
-        limiter = aiolimiter.AsyncLimiter(requests_per_minute)
-
-        async def _throttled_completion_acreate(
-            model: str,
-            messages: list[dict[str, str]],
-            temperature: float,
-            max_tokens: int,
-            n: int,
-            top_p: float,
-            limiter: aiolimiter.AsyncLimiter,
-        ):
-            async with limiter:
-                for _ in range(3):
-                    try:
-                        return await acompletion(
-                            model=model,
-                            messages=messages,
-                            api_base=self.api_base,
-                            temperature=temperature,
-                            max_tokens=max_tokens,
-                            n=n,
-                            top_p=top_p,
-                        )
-                    except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
-                        if isinstance(
-                            e,
-                            (
-                                openai.APIStatusError,
-                                openai.APIError,
-                            ),
-                        ):
-                            logging.warning(
-                                ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
-                            )
-                        elif isinstance(e, openai.BadRequestError):
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                            return {
-                                "choices": [
-                                    {
-                                        "message": {
-                                            "content": "Invalid Request: Prompt was filtered"  # noqa E501
-                                        }
-                                    }
-                                ]
-                            }
-                        else:
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                        await asyncio.sleep(10)
-                return {"choices": [{"message": {"content": ""}}]}
-
-        num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
-        if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
-        else:
-            max_tokens = 3 * num_prompt_tokens
-
-        async_responses = [
-            _throttled_completion_acreate(
-                model=self.model_name,
-                messages=[
-                    {"role": "user", "content": f"{prompt}"},
-                ],
-                temperature=temperature,
-                max_tokens=max_tokens,
-                n=responses_per_request,
-                top_p=1,
-                limiter=limiter,
-            )
-            for prompt in prompts
-        ]
-        responses = await tqdm_asyncio.gather(*async_responses)
-        # Note: will never be none because it's set, but mypy doesn't know that.
-        await openai.aiosession.get().close()
+        async with ClientSession() as _:
+            limiter = aiolimiter.AsyncLimiter(requests_per_minute)
+
+            async def _throttled_completion_acreate(
+                model: str,
+                messages: list[dict[str, str]],
+                temperature: float,
+                max_tokens: int,
+                n: int,
+                top_p: float,
+                limiter: aiolimiter.AsyncLimiter,
+            ):
+                async with limiter:
+                    for _ in range(3):
+                        try:
+                            return await acompletion(
+                                model=model,
+                                messages=messages,
+                                api_base=self.api_base,
+                                temperature=temperature,
+                                max_tokens=max_tokens,
+                                n=n,
+                                top_p=top_p,
+                            )
+                        except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
+                            if isinstance(
+                                e,
+                                (
+                                    openai.APIStatusError,
+                                    openai.APIError,
+                                ),
+                            ):
+                                logging.warning(
+                                    ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+                                )
+                            elif isinstance(e, openai.BadRequestError):
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                                return {
+                                    "choices": [
+                                        {
+                                            "message": {
+                                                "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                                            }
+                                        }
+                                    ]
+                                }
+                            else:
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                            await asyncio.sleep(10)
+                    return {"choices": [{"message": {"content": ""}}]}
+
+            num_prompt_tokens = max(
+                count_tokens_from_string(prompt) for prompt in prompts
+            )
+            if self.max_tokens:
+                max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
+            else:
+                max_tokens = 3 * num_prompt_tokens
+
+            async_responses = [
+                _throttled_completion_acreate(
+                    model=self.model_name,
+                    messages=[
+                        {"role": "user", "content": f"{prompt}"},
+                    ],
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    n=responses_per_request,
+                    top_p=1,
+                    limiter=limiter,
+                )
+                for prompt in prompts
+            ]
+            responses = await tqdm_asyncio.gather(*async_responses)
         return responses


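The substance of the fix: the old code installed a module-global aiohttp session via openai.aiosession.set(ClientSession()) and had to remember to close it manually after gathering responses, so the session could leak if tqdm_asyncio.gather raised, and a comment was needed to placate mypy about the close call. The new code scopes the session with async with, which closes it on every exit path. A minimal standalone sketch of the same lifecycle change; the fetch_status helper and URL are illustrative placeholders, not part of the commit:

import asyncio

from aiohttp import ClientSession


async def fetch_status(session: ClientSession, url: str) -> int:
    # Stand-in for the throttled completion calls in the diff above.
    async with session.get(url) as resp:
        return resp.status


async def main() -> None:
    # Old pattern: session = ClientSession(); ...; await session.close()
    # leaks the session if gather() raises in between.
    # New pattern: the context manager closes the session on success or error.
    async with ClientSession() as session:
        statuses = await asyncio.gather(
            *(fetch_status(session, "https://example.com") for _ in range(3))
        )
    print(statuses)


if __name__ == "__main__":
    asyncio.run(main())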

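For context on the surrounding code, _throttled_completion_acreate (unchanged by this commit apart from its indentation) combines an aiolimiter.AsyncLimiter gate with a three-attempt retry loop and a 10-second backoff. A hedged sketch of that shape, with the actual acompletion call replaced by a placeholder:

import asyncio
import logging

import aiolimiter


async def throttled_call(limiter: aiolimiter.AsyncLimiter, i: int) -> str:
    async with limiter:  # waits until the rate limit admits another request
        for _ in range(3):  # up to three attempts, as in the diff
            try:
                return f"response-{i}"  # placeholder for `await acompletion(...)`
            except Exception as e:  # the real code catches specific API errors
                logging.warning("call %d failed: %s", i, e)
                await asyncio.sleep(10)
        return ""  # empty fallback after exhausting retries, as in the diff


async def main() -> None:
    # AsyncLimiter(n) allows n acquisitions per 60-second window by default,
    # matching the requests_per_minute argument in the diff.
    limiter = aiolimiter.AsyncLimiter(100)
    results = await asyncio.gather(*(throttled_call(limiter, i) for i in range(5)))
    print(results)


if __name__ == "__main__":
    asyncio.run(main())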