diff --git a/superagi/llms/openai.py b/superagi/llms/openai.py index 454f7ce19..62ace6305 100644 --- a/superagi/llms/openai.py +++ b/superagi/llms/openai.py @@ -1,11 +1,20 @@ import openai from openai import APIError, InvalidRequestError -from openai.error import RateLimitError, AuthenticationError +from openai.error import RateLimitError, AuthenticationError, Timeout, TryAgain +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential from superagi.config.config import get_config from superagi.lib.logger import logger from superagi.llms.base_llm import BaseLlm +MAX_RETRY_ATTEMPTS = 5 +MIN_WAIT = 30 # Seconds +MAX_WAIT = 300 # Seconds + +def custom_retry_error_callback(retry_state): + logger.info("OpenAi Exception:", retry_state.outcome.exception()) + return {"error": "ERROR_OPENAI", "message": "Open ai exception: "+str(retry_state.outcome.exception())} + class OpenAi(BaseLlm): def __init__(self, api_key, model="gpt-4", temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1, @@ -50,6 +59,17 @@ def get_model(self): """ return self.model + @retry( + retry=( + retry_if_exception_type(RateLimitError) | + retry_if_exception_type(Timeout) | + retry_if_exception_type(TryAgain) + ), + stop=stop_after_attempt(MAX_RETRY_ATTEMPTS), # Maximum number of retry attempts + wait=wait_random_exponential(min=MIN_WAIT, max=MAX_WAIT), + before_sleep=lambda retry_state: logger.info(f"{retry_state.outcome.exception()} (attempt {retry_state.attempt_number})"), + retry_error_callback=custom_retry_error_callback + ) def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")): """ Call the OpenAI chat completion API. @@ -75,12 +95,18 @@ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT ) content = response.choices[0].message["content"] return {"response": response, "content": content} + except RateLimitError as api_error: + logger.info("OpenAi RateLimitError:", api_error) + raise RateLimitError(str(api_error)) + except Timeout as timeout_error: + logger.info("OpenAi Timeout:", timeout_error) + raise Timeout(str(timeout_error)) + except TryAgain as try_again_error: + logger.info("OpenAi TryAgain:", try_again_error) + raise TryAgain(str(try_again_error)) except AuthenticationError as auth_error: logger.info("OpenAi AuthenticationError:", auth_error) return {"error": "ERROR_AUTHENTICATION", "message": "Authentication error please check the api keys: "+str(auth_error)} - except RateLimitError as api_error: - logger.info("OpenAi RateLimitError:", api_error) - return {"error": "ERROR_RATE_LIMIT", "message": "Openai rate limit exceeded: "+str(api_error)} except InvalidRequestError as invalid_request_error: logger.info("OpenAi InvalidRequestError:", invalid_request_error) return {"error": "ERROR_INVALID_REQUEST", "message": "Openai invalid request error: "+str(invalid_request_error)} diff --git a/tests/unit_tests/llms/test_open_ai.py b/tests/unit_tests/llms/test_open_ai.py index 9882092f4..31b4c576d 100644 --- a/tests/unit_tests/llms/test_open_ai.py +++ b/tests/unit_tests/llms/test_open_ai.py @@ -1,6 +1,8 @@ +import openai import pytest from unittest.mock import MagicMock, patch -from superagi.llms.openai import OpenAi + +from superagi.llms.openai import OpenAi, MAX_RETRY_ATTEMPTS @patch('superagi.llms.openai.openai') @@ -33,6 +35,78 @@ def test_chat_completion(mock_openai): ) +@patch('superagi.llms.openai.wait_random_exponential.__call__') +@patch('superagi.llms.openai.openai') +def test_chat_completion_retry_rate_limit_error(mock_openai, mock_wait_random_exponential): + # Arrange + model = 'gpt-4' + api_key = 'test_key' + openai_instance = OpenAi(api_key, model=model) + + messages = [{"role": "system", "content": "You are a helpful assistant."}] + max_tokens = 100 + + mock_openai.ChatCompletion.create.side_effect = openai.error.RateLimitError("Rate limit exceeded") + + # Mock sleep time + mock_wait_random_exponential.return_value = 0.1 + + # Act + result = openai_instance.chat_completion(messages, max_tokens) + + # Assert + assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Rate limit exceeded"} + assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS + + +@patch('superagi.llms.openai.wait_random_exponential.__call__') +@patch('superagi.llms.openai.openai') +def test_chat_completion_retry_timeout_error(mock_openai, mock_wait_random_exponential): + # Arrange + model = 'gpt-4' + api_key = 'test_key' + openai_instance = OpenAi(api_key, model=model) + + messages = [{"role": "system", "content": "You are a helpful assistant."}] + max_tokens = 100 + + mock_openai.ChatCompletion.create.side_effect = openai.error.Timeout("Timeout occured") + + # Mock sleep time + mock_wait_random_exponential.return_value = 0.1 + + # Act + result = openai_instance.chat_completion(messages, max_tokens) + + # Assert + assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Timeout occured"} + assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS + + +@patch('superagi.llms.openai.wait_random_exponential.__call__') +@patch('superagi.llms.openai.openai') +def test_chat_completion_retry_try_again_error(mock_openai, mock_wait_random_exponential): + # Arrange + model = 'gpt-4' + api_key = 'test_key' + openai_instance = OpenAi(api_key, model=model) + + messages = [{"role": "system", "content": "You are a helpful assistant."}] + max_tokens = 100 + + mock_openai.ChatCompletion.create.side_effect = openai.error.TryAgain("Try Again") + + # Mock sleep time + mock_wait_random_exponential.return_value = 0.1 + + # Act + result = openai_instance.chat_completion(messages, max_tokens) + + # Assert + assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Try Again"} + assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS + + def test_verify_access_key(): model = 'gpt-4' api_key = 'test_key'