| 
3 | 3 | import unittest  | 
4 | 4 | 
 
  | 
5 | 5 | from gradio_client import Client  | 
6 |  | -from unittest.mock import patch, MagicMock  | 
 | 6 | +from unittest.mock import patch, MagicMock, Mock  | 
7 | 7 | from langchain.schema import HumanMessage, AIMessage, SystemMessage  | 
8 |  | -from app import build_chat_context, inference, PossibleSystemPromptException, BACKEND_INITIALISED  | 
 | 8 | +from app import build_chat_context, inference, PossibleSystemPromptException, gr  | 
9 | 9 | 
 
  | 
10 | 10 | url = os.environ.get("GRADIO_URL", "http://localhost:7860")  | 
11 | 11 | client = Client(url)  | 
@@ -95,6 +95,111 @@ def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):  | 
95 | 95 | 
 
  | 
96 | 96 |         self.assertEqual(responses, ["Thinking...", "Thinking...", "", "final response"])  | 
97 | 97 | 
 
  | 
 | 98 | +    @patch("app.llm")  | 
 | 99 | +    @patch("app.INCLUDE_SYSTEM_PROMPT", True)  | 
 | 100 | +    @patch("app.build_chat_context")  | 
 | 101 | +    def test_inference_PossibleSystemPromptException(self, mock_build_chat_context, mock_llm):  | 
 | 102 | +        mock_build_chat_context.return_value = ["mock_context"]  | 
 | 103 | +        mock_response = Mock()  | 
 | 104 | +        mock_response.json.return_value = {"message": "Bad request"}  | 
 | 105 | + | 
 | 106 | +        mock_llm.stream.side_effect = openai.BadRequestError(  | 
 | 107 | +            message="Bad request",  | 
 | 108 | +            response=mock_response,  | 
 | 109 | +            body=None  | 
 | 110 | +        )  | 
 | 111 | + | 
 | 112 | +        latest_message = "Hello"  | 
 | 113 | +        history = []  | 
 | 114 | + | 
 | 115 | +        with self.assertRaises(PossibleSystemPromptException):  | 
 | 116 | +            list(inference(latest_message, history))  | 
 | 117 | + | 
 | 118 | +    @patch("app.llm")  | 
 | 119 | +    @patch("app.INCLUDE_SYSTEM_PROMPT", False)  | 
 | 120 | +    @patch("app.build_chat_context")  | 
 | 121 | +    def test_inference_general_error(self, mock_build_chat_context, mock_llm):  | 
 | 122 | +        mock_build_chat_context.return_value = ["mock_context"]  | 
 | 123 | +        mock_response = Mock()  | 
 | 124 | +        mock_response.json.return_value = {"message": "Bad request"}  | 
 | 125 | + | 
 | 126 | +        mock_llm.stream.side_effect = openai.BadRequestError(  | 
 | 127 | +            message="Bad request",  | 
 | 128 | +            response=mock_response,  | 
 | 129 | +            body=None  | 
 | 130 | +        )  | 
 | 131 | + | 
 | 132 | +        latest_message = "Hello"  | 
 | 133 | +        history = []  | 
 | 134 | +        exception_message = "\'API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: Bad request\'"  | 
 | 135 | + | 
 | 136 | +        with self.assertRaises(gr.Error) as gradio_error:  | 
 | 137 | +            list(inference(latest_message, history))  | 
 | 138 | +        self.assertEqual(str(gradio_error.exception), exception_message)  | 
 | 139 | + | 
 | 140 | +    @patch("app.llm")  | 
 | 141 | +    @patch("app.build_chat_context")  | 
 | 142 | +    @patch("app.log")  | 
 | 143 | +    @patch("app.gr")  | 
 | 144 | +    @patch("app.BACKEND_INITIALISED", False)  | 
 | 145 | +    def test_inference_APIConnectionError(self, mock_gr, mock_logger, mock_build_chat_context, mock_llm):  | 
 | 146 | +        mock_build_chat_context.return_value = ["mock_context"]  | 
 | 147 | +        mock_request = Mock()  | 
 | 148 | +        mock_request.json.return_value = {"message": "Foo"}  | 
 | 149 | + | 
 | 150 | +        mock_llm.stream.side_effect = openai.APIConnectionError(  | 
 | 151 | +            message="Foo",  | 
 | 152 | +            request=mock_request,  | 
 | 153 | +        )  | 
 | 154 | + | 
 | 155 | +        latest_message = "Hello"  | 
 | 156 | +        history = []  | 
 | 157 | + | 
 | 158 | +        list(inference(latest_message, history))  | 
 | 159 | +        mock_logger.info.assert_any_call("Backend API not yet ready")  | 
 | 160 | +        mock_gr.Info.assert_any_call("Backend not ready - model may still be initialising - please try again later.")  | 
 | 161 | + | 
 | 162 | +    @patch("app.llm")  | 
 | 163 | +    @patch("app.build_chat_context")  | 
 | 164 | +    @patch("app.log")  | 
 | 165 | +    @patch("app.gr")  | 
 | 166 | +    @patch("app.BACKEND_INITIALISED", True)  | 
 | 167 | +    def test_inference_APIConnectionError_initialised(self, mock_gr, mock_logger, mock_build_chat_context, mock_llm):  | 
 | 168 | +        mock_build_chat_context.return_value = ["mock_context"]  | 
 | 169 | +        mock_request = Mock()  | 
 | 170 | +        mock_request.json.return_value = {"message": "Foo"}  | 
 | 171 | + | 
 | 172 | +        mock_llm.stream.side_effect = openai.APIConnectionError(  | 
 | 173 | +            message="Foo",  | 
 | 174 | +            request=mock_request,  | 
 | 175 | +        )  | 
 | 176 | + | 
 | 177 | +        latest_message = "Hello"  | 
 | 178 | +        history = []  | 
 | 179 | + | 
 | 180 | +        list(inference(latest_message, history))  | 
 | 181 | +        mock_logger.error.assert_called_once_with("Failed to connect to backend API: %s", mock_llm.stream.side_effect)  | 
 | 182 | +        mock_gr.Warning.assert_any_call("Failed to connect to backend API.")  | 
 | 183 | + | 
 | 184 | +    @patch("app.llm")  | 
 | 185 | +    @patch("app.build_chat_context")  | 
 | 186 | +    @patch("app.gr")  | 
 | 187 | +    def test_inference_InternalServerError(self, mock_gr, mock_build_chat_context, mock_llm):  | 
 | 188 | +        mock_build_chat_context.return_value = ["mock_context"]  | 
 | 189 | +        mock_request = Mock()  | 
 | 190 | +        mock_request.json.return_value = {"message": "Foo"}  | 
 | 191 | + | 
 | 192 | +        mock_llm.stream.side_effect = openai.InternalServerError(  | 
 | 193 | +            message="Foo",  | 
 | 194 | +            response=mock_request,  | 
 | 195 | +            body=None  | 
 | 196 | +        )  | 
 | 197 | + | 
 | 198 | +        latest_message = "Hello"  | 
 | 199 | +        history = []  | 
 | 200 | + | 
 | 201 | +        list(inference(latest_message, history))  | 
 | 202 | +        mock_gr.Warning.assert_any_call("Internal server error encountered in backend API - see API logs for details.")  | 
98 | 203 | 
 
  | 
99 | 204 | if __name__ == "__main__":  | 
100 | 205 |     unittest.main()  | 
0 commit comments