From 02237496a1903cf02f9a328e73892394d4dfc5c7 Mon Sep 17 00:00:00 2001 From: Daniel McKnight <34697904+NeonDaniel@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:53:48 -0800 Subject: [PATCH] Implement FastAPI Unit Tests (#34) # Description Adds test coverage for all exposed endpoints with mocked backend methods Addresses some edge-case bugs in `request.client` object handling Fixes/completes model type annotations # Issues # Other Notes --- neon_hana/app/routers/api_proxy.py | 2 +- neon_hana/app/routers/assist.py | 3 +- neon_hana/app/routers/auth.py | 3 +- neon_hana/app/routers/util.py | 23 +- neon_hana/auth/client_manager.py | 3 +- neon_hana/mq_service_api.py | 2 +- neon_hana/schema/api_responses.py | 7 +- neon_hana/schema/llm_requests.py | 6 +- requirements/test_requirements.txt | 1 + tests/test_app.py | 557 +++++++++++++++++++++++++++++ 10 files changed, 594 insertions(+), 13 deletions(-) create mode 100644 tests/test_app.py diff --git a/neon_hana/app/routers/api_proxy.py b/neon_hana/app/routers/api_proxy.py index 25135f1..fdb8fbb 100644 --- a/neon_hana/app/routers/api_proxy.py +++ b/neon_hana/app/routers/api_proxy.py @@ -66,4 +66,4 @@ async def api_proxy_geolocation(query: GeoAPIReverseRequest) -> GeoAPIReverseRes @proxy_route.post("/wolframalpha") async def api_proxy_wolframalpha(query: WolframAlphaAPIRequest) -> WolframAlphaAPIResponse: - return mq_connector.query_api_proxy("wolfram_alpha", dict(query)) \ No newline at end of file + return mq_connector.query_api_proxy("wolfram_alpha", dict(query)) diff --git a/neon_hana/app/routers/assist.py b/neon_hana/app/routers/assist.py index 52d88f0..8bcc432 100644 --- a/neon_hana/app/routers/assist.py +++ b/neon_hana/app/routers/assist.py @@ -47,5 +47,6 @@ async def get_tts(request: TTSRequest) -> TTSResponse: async def get_response(skill_request: SkillRequest, request: Request) -> SkillResponse: if not skill_request.node_data.networking.public_ip: - skill_request.node_data.networking.public_ip = request.client.host + host = request.client.host if request.client else "" + skill_request.node_data.networking.public_ip = host return mq_connector.get_response(**dict(skill_request)) diff --git a/neon_hana/app/routers/auth.py b/neon_hana/app/routers/auth.py index 4cf78e2..14a9359 100644 --- a/neon_hana/app/routers/auth.py +++ b/neon_hana/app/routers/auth.py @@ -35,8 +35,9 @@ @auth_route.post("/login") async def check_login(auth_request: AuthenticationRequest, request: Request) -> AuthenticationResponse: + ip_addr = request.client.host if request.client else "127.0.0.1" return client_manager.check_auth_request(**dict(auth_request), - origin_ip=request.client.host) + origin_ip=ip_addr) @auth_route.post("/refresh") diff --git a/neon_hana/app/routers/util.py b/neon_hana/app/routers/util.py index a4bf674..2d62d94 100644 --- a/neon_hana/app/routers/util.py +++ b/neon_hana/app/routers/util.py @@ -24,6 +24,8 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import re + from fastapi import APIRouter, Request from starlette.responses import PlainTextResponse @@ -32,13 +34,28 @@ util_route = APIRouter(prefix="/util", tags=["utilities"]) +def _is_ipv4(address: str) -> bool: + ipv4_regex = re.compile( + r'^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01' + r']?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|' + r'2[0-4][0-9]|[01]?[0-9][0-9]?)$') + return ipv4_regex.match(address) + + @util_route.get("/client_ip", response_class=PlainTextResponse) async def api_client_ip(request: Request) -> str: - client_manager.validate_auth("", request.client.host) - return request.client.host + ip_addr = request.client.host if request.client else "127.0.0.1" + + if not _is_ipv4(ip_addr): + # Reported host is a hostname, not an IP address. Return a generic + # loopback value + ip_addr = "127.0.0.1" + client_manager.validate_auth("", ip_addr) + return ip_addr @util_route.get("/headers") async def api_headers(request: Request): - client_manager.validate_auth("", request.client.host) + ip_addr = request.client.host if request.client else "127.0.0.1" + client_manager.validate_auth("", ip_addr) return request.headers diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 578ea13..7ad2e26 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -235,8 +235,9 @@ async def __call__(self, request: Request): if not credentials.scheme == "Bearer": raise HTTPException(status_code=403, detail="Invalid authentication scheme.") + host = request.client.host if request.client else "127.0.0.1" if not self.client_manager.validate_auth(credentials.credentials, - request.client.host): + host): raise HTTPException(status_code=403, detail="Invalid or expired token.") return credentials.credentials diff --git a/neon_hana/mq_service_api.py b/neon_hana/mq_service_api.py index 032e800..a36e5e3 100644 --- a/neon_hana/mq_service_api.py +++ b/neon_hana/mq_service_api.py @@ -57,7 +57,7 @@ def _validate_api_proxy_response(response: dict, query_params: dict): try: resp = json.loads(response['content']) if query_params.get('service') == "alpha_vantage": - resp['service'] = query_params['service'] + resp['provider'] = query_params['service'] if query_params.get("region") and resp.get('bestMatches'): filtered = [ stock for stock in resp.get("bestMatches") diff --git a/neon_hana/schema/api_responses.py b/neon_hana/schema/api_responses.py index 848955b..5f0e77b 100644 --- a/neon_hana/schema/api_responses.py +++ b/neon_hana/schema/api_responses.py @@ -34,7 +34,7 @@ class WeatherAPIOnecallResponse(BaseModel): timezone: str timezone_offset: int current: Dict[str, Any] - minutely: Optional[List[dict]] + minutely: Optional[List[dict]] = None hourly: List[dict] daily: List[dict] @@ -1742,7 +1742,8 @@ class WeatherAPIOnecallResponse(BaseModel): class StockAPIQuoteResponse(BaseModel): - global_quote: Dict[str, str] = Field(..., alias="Global Quote") + provider: str + global_quote: Dict[str, str] = Field(alias="Global Quote") model_config = { "extra": "allow", @@ -1767,6 +1768,8 @@ class StockAPIQuoteResponse(BaseModel): class StockAPISearchResponse(BaseModel): + provider: str + best_matches: List[Dict[str, str]] = Field(alias="bestMatches") model_config = { "extra": "allow", "json_schema_extra": { diff --git a/neon_hana/schema/llm_requests.py b/neon_hana/schema/llm_requests.py index 1861f4d..4085a69 100644 --- a/neon_hana/schema/llm_requests.py +++ b/neon_hana/schema/llm_requests.py @@ -24,7 +24,7 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import List +from typing import List, Tuple from pydantic import BaseModel @@ -42,11 +42,11 @@ class LLMRequest(BaseModel): class LLMResponse(BaseModel): response: str - history: List[tuple] + history: List[Tuple[str, str]] model_config = { "json_schema_extra": { "examples": [{ - "query": "I am well, how about you?", + "response": "As a large language model, I do not feel", "history": [("user", "hello"), ("llm", "Hi, how can I help you today?"), ("user", "I am well, how about you?"), diff --git a/requirements/test_requirements.txt b/requirements/test_requirements.txt index 98ff01c..df9612e 100644 --- a/requirements/test_requirements.txt +++ b/requirements/test_requirements.txt @@ -1,4 +1,5 @@ pytest mock +httpx neon-iris~=0.1 websockets~=12.0 \ No newline at end of file diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 0000000..ef546b1 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,557 @@ +import json +from time import time +from unittest import TestCase +from unittest.mock import patch + +from fastapi.testclient import TestClient + +_TEST_CONFIG = { + "mq_default_timeout": 10, + "access_token_ttl": 86400, # 1 day + "refresh_token_ttl": 604800, # 1 week + "requests_per_minute": 60, + "auth_requests_per_minute": 60, + "access_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b", + "refresh_token_secret": "833d369ac73d883123743a44b4a7fe21203cffc956f4c8a99be6e71aafa8e1aa", + "server_host": "0.0.0.0", + "server_port": 8080, + "fastapi_title": "Test Client Title", + "fastapi_summary": "Test Client Summary", + "stt_max_length_encoded": 500000, + "tts_max_words": 128, + "enable_email": True +} + + +class TestHanaApp(TestCase): + test_app: TestClient = None + tokens: dict = None + + @classmethod + @patch("ovos_config.config.Configuration") + @patch("neon_hana.mq_websocket_api.MQWebsocketAPI") + def setUpClass(cls, ws_api, config): + config.return_value = {"hana": _TEST_CONFIG} + from neon_hana.app import create_app + app = create_app(_TEST_CONFIG) + cls.test_app = TestClient(app) + + def _get_tokens(self): + if not self.tokens: + response = self.test_app.post("/auth/login", + json={"username": "guest", + "password": "password"}) + self.tokens = response.json() + self.assertIn("access_token", self.tokens, self.tokens) + return self.tokens + + def test_app_init(self): + self.assertEqual(self.test_app.app.title, _TEST_CONFIG["fastapi_title"]) + self.assertEqual(self.test_app.app.summary, + _TEST_CONFIG["fastapi_summary"]) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_auth_login(self, send_request): + send_request.return_value = {} # TODO: Define valid login + + # Valid Login + response = self.test_app.post("/auth/login", + json={"username": "guest", + "password": "password"}) + response_data = response.json() + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response_data['username'], "guest") + self.assertIsInstance(response_data['access_token'], str) + self.assertIsInstance(response_data['refresh_token'], str) + self.assertGreater(response_data['expiration'], time()) + + # Invalid Login + # TODO: Define invalid login request + + # Invalid Request + self.assertEqual(self.test_app.post("/auth/login").status_code, 422) + self.assertEqual(self.test_app.post("/auth/login", + json={"username": None}).status_code, + 422) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_auth_refresh(self, send_request): + send_request.return_value = {} # TODO: Define valid refresh + + valid_tokens = self._get_tokens() + + # Valid request + response = self.test_app.post("/auth/refresh", json=valid_tokens) + self.assertEqual(response.status_code, 200, response.text) + response_data = response.json() + self.assertNotEqual(response_data, valid_tokens) + + # # TODO + # # Refresh with old tokens fails + # response = self.test_app.post("/auth/refresh", json=valid_tokens) + # self.assertEqual(response.status_code, 422, response.text) + + # Valid request with new tokens + response = self.test_app.post("/auth/refresh", json=response_data) + self.assertEqual(response.status_code, 200, response.text) + + # TODO: Test with expired token + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_assist_get_stt(self, send_request): + send_request.return_value = {"data": {"transcripts": ["test"], + "parser_data": {"test": True}}} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/neon/get_stt", + json={"encoded_audio": "MOCK_B64_AUDIO", + "lang_code": "en-us"}, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json(), send_request.return_value['data']) + + # Invalid missing auth + response = self.test_app.post("/neon/get_stt", + json={"encoded_audio": "MOCK_B64_AUDIO", + "lang_code": "en-us"}) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/neon/get_stt", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_assist_get_tts(self, send_request): + send_request.return_value = {"data": { + "en-us": {"audio": {"female": "MOCK_B64_AUDIO"}}}} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/neon/get_tts", + json={"to_speak": "test", + "lang_code": "en-us"}, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json()['encoded_audio'], "MOCK_B64_AUDIO") + + # Invalid missing auth + response = self.test_app.post("/neon/get_tts", + json={"to_speak": "test", + "lang_code": "en-us"}) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/neon/get_tts", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_assist_get_response(self, send_request): + send_request.return_value = { + "data": {"responses": {"en-us": {"sentence": "mock_response"}}}, + "context": {"session": {"new_session": True}}} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/neon/get_response", + json={"utterance": "test", + "lang_code": "en-us"}, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json()['answer'], "mock_response") + self.assertEqual(response.json()['lang_code'], "en-us") + + # Invalid missing auth + response = self.test_app.post("/neon/get_response", + json={"utterance": "test", + "lang_code": "en-us"}) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/neon/get_response", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_proxy_weather(self, send_request): + send_request.return_value = {"status_code": 200, + "content": json.dumps( + {"lat": 47.6815, + "lon": -122.2087, + "timezone": "America/Los_Angeles", + "timezone_offset": -28800, + "current": {}, + "minutely": [], + "hourly": [], + "daily": []})} + valid_request = {"lat": 47.6815, + "lon": -122.2087, + "unit": "metric"} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/proxy/weather", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json(), + json.loads(send_request.return_value['content']), + response.json()) + + # Invalid missing auth + response = self.test_app.post("/proxy/weather", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/proxy/weather", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_proxy_stock_symbol(self, send_request): + send_request.return_value = {"status_code": 200, + "content": json.dumps( + {"bestMatches": []})} + valid_request = {"company": "microsoft", + "region": "United States"} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/proxy/stock/symbol", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json()['bestMatches'], + json.loads(send_request.return_value['content'])['bestMatches'], + response.json()) + self.assertEqual(response.json()['provider'], "alpha_vantage") + + # Invalid missing auth + response = self.test_app.post("/proxy/stock/symbol", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/proxy/stock/symbol", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + # TODO test region filtering + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_proxy_stock_quote(self, send_request): + send_request.return_value = {"status_code": 200, + "content": json.dumps( + {"Global Quote": {"test": "True"}})} + valid_request = {"symbol": "GOOG"} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/proxy/stock/quote", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json()["Global Quote"], + json.loads(send_request.return_value['content'])["Global Quote"], + response.json()) + + # Invalid missing auth + response = self.test_app.post("/proxy/stock/quote", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/proxy/stock/quote", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_proxy_geocode(self, send_request): + send_request.return_value = {"status_code": 200, + "content": json.dumps( + {"place_id": 0, + "licence": "test", + "osm_type": "test", + "osm_id": 0, + "boundingbox": ["0", "0", "0", "0"], + "lat": "47.6815", + "lon": "-122.2087", + "display_name": "test", + "class": "amenity", + "type": "post_office", + "importance": 1.0, + "alternate_results": []})} + valid_request = {"address": "1100 Bellevue Way NE Bellevue, WA"} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/proxy/geolocation/geocode", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json(), + json.loads(send_request.return_value['content']), + response.json()) + + # Invalid missing auth + response = self.test_app.post("/proxy/geolocation/geocode", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/proxy/geolocation/geocode", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_proxy_geocode_reverse(self, send_request): + send_request.return_value = {"status_code": 200, + "content": json.dumps( + {"place_id": 0, + "licence": "test", + "osm_type": "test", + "osm_id": 0, + "boundingbox": ["0", "0", "0", "0"], + "lat": "47.6815", + "lon": "-122.2087", + "display_name": "test", + "address": {}})} + + valid_request = {"lat": 47.6815, "lon": -122.2087} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/proxy/geolocation/reverse", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json(), + json.loads(send_request.return_value['content']), + response.json()) + + # Invalid missing auth + response = self.test_app.post("/proxy/geolocation/reverse", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/proxy/geolocation/reverse", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_proxy_wolfram(self, send_request): + send_request.return_value = {"status_code": 200, + "content": "answer"} + valid_request = {"api": "spoken", "lat": 47.6815, "lon": -122.2087, + "query": "how far away is the moon"} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/proxy/wolframalpha", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json(), + {"answer": send_request.return_value['content']}, + response.json()) + + # Invalid missing auth + response = self.test_app.post("/proxy/wolframalpha", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/proxy/wolframalpha", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_backend_email(self, send_request): + send_request.return_value = {"success": True} + valid_request = {"recipient": "developers@neon.ai", + "subject": "API test", + "body": "This is a test.\nGenerated from OpenAPI.", + "attachments": { + "test.txt": "VGhpcyBpcyBhIHRlc3QgZmlsZQo="}} + + token = self._get_tokens()["access_token"] + # Valid request + response = self.test_app.post("/email", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + + # Invalid missing auth + response = self.test_app.post("/email", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/email", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + # Valid request failed + send_request.return_value = {"success": False, + "error": "Something has failed"} + response = self.test_app.post("/email", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 500, response.text) + self.assertEqual(response.json()['detail'], "Something has failed") + + # TODO: Test disabled service + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_backend_metrics(self, send_request): + send_request.return_value = {} + valid_request = {"metric_name": "Unit Test", + "timestamp": str(time()), + "metric_data": {"test": True}} + token = self._get_tokens()["access_token"] + + # Valid request + response = self.test_app.post("/metrics/upload", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + + # Invalid missing auth + response = self.test_app.post("/metrics/upload", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/metrics/upload", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_backend_ccl(self, send_request): + send_request.return_value = {"parsed_file": "MOCK_NCS_DATA"} + valid_request = {"script": "MOCK_SCRIPT_DATA"} + token = self._get_tokens()["access_token"] + + # Valid request + response = self.test_app.post("/ccl/parse", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json()['ncs'], "MOCK_NCS_DATA") + + # Invalid missing auth + response = self.test_app.post("/ccl/parse", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + # Invalid request + self.assertEqual(self.test_app.post( + "/ccl/parse", + headers={"Authorization": f"Bearer {token}"}).status_code, + 422, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_backend_coupons(self, send_request): + send_request.return_value = {"success": True, "brands": [], + "coupons": []} + token = self._get_tokens()["access_token"] + + # Valid request + response = self.test_app.post("/coupons", + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200, response.text) + self.assertEqual(response.json(), send_request.return_value) + + # Invalid missing auth + response = self.test_app.post("/coupons") + self.assertEqual(response.status_code, 403, response.text) + + @patch("neon_hana.mq_service_api.send_mq_request") + def test_llm(self, send_request): + send_request.return_value = {"response": "MOCK_LLM_RESPONSE"} + valid_request = {"query": "how are you?", + "history": [("user", "hello"), + ("llm", "Hi, how can I help you today?")]} + # Responses are lists instead of tuples because Pydantic will auto-cast + # for JSON encoding + valid_response = {"response": "MOCK_LLM_RESPONSE", + "history": [["user", "hello"], + ["llm", "Hi, how can I help you today?"], + ["user", "how are you?"], + ["llm", "MOCK_LLM_RESPONSE"]]} + token = self._get_tokens()["access_token"] + # ChatGPT + response = self.test_app.post("/llm/chatgpt", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) + + # Fastchat + response = self.test_app.post("/llm/fastchat", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) + + # Claude + response = self.test_app.post("/llm/claude", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) + + # Palm + response = self.test_app.post("/llm/palm", + json=valid_request, + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), valid_response) + + # Invalid requests + response = self.test_app.post("/llm/chatgpt", + json=valid_request) + self.assertEqual(response.status_code, 403, response.text) + + response = self.test_app.post("/llm/chatgpt", + headers={"Authorization": f"Bearer {token}"}) + self.assertEqual(response.status_code, 422, response.text) + + def test_util_is_ipv4(self): + from neon_hana.app.routers.util import _is_ipv4 + self.assertTrue(_is_ipv4("127.0.0.1")) + self.assertTrue(_is_ipv4("10.0.0.10")) + self.assertTrue(_is_ipv4("1.1.1.1")) + self.assertFalse(_is_ipv4("ai.neon.api.1")) + self.assertFalse(_is_ipv4("host.local")) + self.assertFalse(_is_ipv4("localhost")) + self.assertFalse(_is_ipv4("1.0.0.300")) + + def test_util_client_ip(self): + response = self.test_app.get("/util/client_ip") + self.assertEqual(response.text, "127.0.0.1") + + def test_util_headers(self): + test_headers = {"X-Auth-Token": "Token", + "Authorization": "Test Auth", + "My Custom Header": "Value"} + response = self.test_app.get("/util/headers", headers=test_headers) + for key, val in test_headers.items(): + self.assertEqual(response.json()[key.lower()], val, response.json()) + +# TODO: Define node endpoint tests