Skip to content

Commit

Permalink
Openai error handling for ratelimit, timeout and try again errors(fix #…
Browse files Browse the repository at this point in the history
…1255) (#1361)

* Adds error handling for openai's rate limit error in llms/openai module and its tests

- Adds test for rate limit error handling in the llms/openai module
- Adds error handling for rate limit error in the llms/openai module
- Refactors code in llms/openai module to be readable and modular

* Adds error handling for openai's timeout error in llms/openai module and its test

- Adds test for timeout error handling in chat_completion in llms/openai module
- Adds error handling for openai's timeout error in chat_completion in llms/openai module

* Adds error handling for openai's try again error in llms/openai module and its test

- Adds test for openai's try again error handling in chat_completion in llms/openai module
- Adds error handling for openai's try again error in chat_completion in llms/openai module

* Refactors llms/openai module and its tests to return error after retry attempts are exhausted

* Increases wait time for retry of chat_completion in llms/openai module

* Removes unused import
  • Loading branch information
aleric-cusher authored Dec 13, 2023
1 parent ea2a0b6 commit 240d05d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 5 deletions.
34 changes: 30 additions & 4 deletions superagi/llms/openai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import openai
from openai import APIError, InvalidRequestError
from openai.error import RateLimitError, AuthenticationError
from openai.error import RateLimitError, AuthenticationError, Timeout, TryAgain
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from superagi.config.config import get_config
from superagi.lib.logger import logger
from superagi.llms.base_llm import BaseLlm

MAX_RETRY_ATTEMPTS = 5
MIN_WAIT = 30  # Seconds
MAX_WAIT = 300  # Seconds


def custom_retry_error_callback(retry_state):
    """Final fallback invoked by tenacity once every retry attempt is used up.

    Logs the last exception captured by the retry machinery and returns a
    structured error dict to the caller instead of re-raising.
    """
    last_exception = retry_state.outcome.exception()
    logger.info("OpenAi Exception:", last_exception)
    return {
        "error": "ERROR_OPENAI",
        "message": "Open ai exception: " + str(last_exception),
    }


class OpenAi(BaseLlm):
def __init__(self, api_key, model="gpt-4", temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1,
Expand Down Expand Up @@ -50,6 +59,17 @@ def get_model(self):
"""
return self.model

@retry(
retry=(
retry_if_exception_type(RateLimitError) |
retry_if_exception_type(Timeout) |
retry_if_exception_type(TryAgain)
),
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS), # Maximum number of retry attempts
wait=wait_random_exponential(min=MIN_WAIT, max=MAX_WAIT),
before_sleep=lambda retry_state: logger.info(f"{retry_state.outcome.exception()} (attempt {retry_state.attempt_number})"),
retry_error_callback=custom_retry_error_callback
)
def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
"""
Call the OpenAI chat completion API.
Expand All @@ -75,12 +95,18 @@ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT
)
content = response.choices[0].message["content"]
return {"response": response, "content": content}
except RateLimitError as api_error:
logger.info("OpenAi RateLimitError:", api_error)
raise RateLimitError(str(api_error))
except Timeout as timeout_error:
logger.info("OpenAi Timeout:", timeout_error)
raise Timeout(str(timeout_error))
except TryAgain as try_again_error:
logger.info("OpenAi TryAgain:", try_again_error)
raise TryAgain(str(try_again_error))
except AuthenticationError as auth_error:
logger.info("OpenAi AuthenticationError:", auth_error)
return {"error": "ERROR_AUTHENTICATION", "message": "Authentication error please check the api keys: "+str(auth_error)}
except RateLimitError as api_error:
logger.info("OpenAi RateLimitError:", api_error)
return {"error": "ERROR_RATE_LIMIT", "message": "Openai rate limit exceeded: "+str(api_error)}
except InvalidRequestError as invalid_request_error:
logger.info("OpenAi InvalidRequestError:", invalid_request_error)
return {"error": "ERROR_INVALID_REQUEST", "message": "Openai invalid request error: "+str(invalid_request_error)}
Expand Down
76 changes: 75 additions & 1 deletion tests/unit_tests/llms/test_open_ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import openai
import pytest
from unittest.mock import MagicMock, patch
from superagi.llms.openai import OpenAi

from superagi.llms.openai import OpenAi, MAX_RETRY_ATTEMPTS


@patch('superagi.llms.openai.openai')
Expand Down Expand Up @@ -33,6 +35,78 @@ def test_chat_completion(mock_openai):
)


@patch('superagi.llms.openai.wait_random_exponential.__call__')
@patch('superagi.llms.openai.openai')
def test_chat_completion_retry_rate_limit_error(mock_openai, mock_wait_random_exponential):
    """After exhausting all retries on RateLimitError, chat_completion must
    return the retry callback's structured error dict."""
    # Arrange: every API call raises, and the exponential backoff is shrunk
    # so the retries complete quickly in the test run.
    openai_instance = OpenAi('test_key', model='gpt-4')
    mock_openai.ChatCompletion.create.side_effect = openai.error.RateLimitError("Rate limit exceeded")
    mock_wait_random_exponential.return_value = 0.1

    # Act
    result = openai_instance.chat_completion(
        [{"role": "system", "content": "You are a helpful assistant."}], 100
    )

    # Assert: error payload surfaced and the API was hit exactly
    # MAX_RETRY_ATTEMPTS times before giving up.
    assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Rate limit exceeded"}
    assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS


@patch('superagi.llms.openai.wait_random_exponential.__call__')
@patch('superagi.llms.openai.openai')
def test_chat_completion_retry_timeout_error(mock_openai, mock_wait_random_exponential):
    """After exhausting all retries on Timeout, chat_completion must return
    the retry callback's structured error dict."""
    # Arrange: every API call times out, and backoff is shortened so the
    # test does not actually wait between attempts.
    openai_instance = OpenAi('test_key', model='gpt-4')
    mock_openai.ChatCompletion.create.side_effect = openai.error.Timeout("Timeout occured")
    mock_wait_random_exponential.return_value = 0.1

    # Act
    result = openai_instance.chat_completion(
        [{"role": "system", "content": "You are a helpful assistant."}], 100
    )

    # Assert: error payload surfaced and the API was hit exactly
    # MAX_RETRY_ATTEMPTS times before giving up.
    assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Timeout occured"}
    assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS


@patch('superagi.llms.openai.wait_random_exponential.__call__')
@patch('superagi.llms.openai.openai')
def test_chat_completion_retry_try_again_error(mock_openai, mock_wait_random_exponential):
    """After exhausting all retries on TryAgain, chat_completion must return
    the retry callback's structured error dict."""
    # Arrange: every API call asks the client to try again, and backoff is
    # shortened so the retries run without real delays.
    openai_instance = OpenAi('test_key', model='gpt-4')
    mock_openai.ChatCompletion.create.side_effect = openai.error.TryAgain("Try Again")
    mock_wait_random_exponential.return_value = 0.1

    # Act
    result = openai_instance.chat_completion(
        [{"role": "system", "content": "You are a helpful assistant."}], 100
    )

    # Assert: error payload surfaced and the API was hit exactly
    # MAX_RETRY_ATTEMPTS times before giving up.
    assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Try Again"}
    assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS


def test_verify_access_key():
model = 'gpt-4'
api_key = 'test_key'
Expand Down

0 comments on commit 240d05d

Please sign in to comment.