diff --git a/superagi/llms/openai.py b/superagi/llms/openai.py index 82ddb34f5..c7cef15fd 100644 --- a/superagi/llms/openai.py +++ b/superagi/llms/openai.py @@ -11,6 +11,10 @@ MIN_WAIT = 30 # Seconds MAX_WAIT = 120 # Seconds +def custom_retry_error_callback(retry_state): + logger.info("OpenAi Exception: %s", retry_state.outcome.exception()) + return {"error": "ERROR_OPENAI", "message": "Open ai exception: "+str(retry_state.outcome.exception())} + class OpenAi(BaseLlm): def __init__(self, api_key, model="gpt-4", temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1, @@ -64,6 +68,7 @@ def get_model(self): stop=stop_after_attempt(MAX_RETRY_ATTEMPTS), # Maximum number of retry attempts wait=wait_random_exponential(min=MIN_WAIT, max=MAX_WAIT), before_sleep=lambda retry_state: logger.info(f"{retry_state.outcome.exception()} (attempt {retry_state.attempt_number})"), + retry_error_callback=custom_retry_error_callback ) def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")): """ diff --git a/tests/unit_tests/llms/test_open_ai.py b/tests/unit_tests/llms/test_open_ai.py index 5ab67bef8..83343d9d3 100644 --- a/tests/unit_tests/llms/test_open_ai.py +++ b/tests/unit_tests/llms/test_open_ai.py @@ -53,10 +53,10 @@ def test_chat_completion_retry_rate_limit_error(mock_openai, mock_wait_random_ex mock_wait_random_exponential.return_value = 0.1 # Act - with pytest.raises(tenacity.RetryError): - result = openai_instance.chat_completion(messages, max_tokens) + result = openai_instance.chat_completion(messages, max_tokens) # Assert + assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Rate limit exceeded"} assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS @@ -77,10 +77,10 @@ def test_chat_completion_retry_timeout_error(mock_openai, mock_wait_random_expon mock_wait_random_exponential.return_value = 0.1 # Act - with pytest.raises(tenacity.RetryError): - result = 
openai_instance.chat_completion(messages, max_tokens) + result = openai_instance.chat_completion(messages, max_tokens) # Assert + assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Timeout occured"} assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS @@ -101,10 +101,10 @@ def test_chat_completion_retry_try_again_error(mock_openai, mock_wait_random_exp mock_wait_random_exponential.return_value = 0.1 # Act - with pytest.raises(tenacity.RetryError): - result = openai_instance.chat_completion(messages, max_tokens) + result = openai_instance.chat_completion(messages, max_tokens) # Assert + assert result == {"error": "ERROR_OPENAI", "message": "Open ai exception: Try Again"} assert mock_openai.ChatCompletion.create.call_count == MAX_RETRY_ATTEMPTS