diff --git a/app.py b/app.py
index 94373df..67b8c48 100644
--- a/app.py
+++ b/app.py
@@ -1,22 +1,42 @@
 import os
+from contextlib import asynccontextmanager
+from typing import Optional, List
 from logging import getLogger
-from fastapi import FastAPI, Response, status
+from fastapi import FastAPI, Depends, Response, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from typing import Union
-from config import TRUST_REMOTE_CODE
+from config import TRUST_REMOTE_CODE, get_allowed_tokens
 from vectorizer import Vectorizer, VectorInput
 from meta import Meta
 
 
-app = FastAPI()
+logger = getLogger("uvicorn")
+
 vec: Vectorizer
 meta_config: Meta
-logger = getLogger("uvicorn")
+
+get_bearer_token = HTTPBearer(auto_error=False)
+# None disables authentication; set from the environment during startup.
+allowed_tokens: Optional[List[str]] = None
+
+
+def is_authorized(auth: Optional[HTTPAuthorizationCredentials]) -> bool:
+    """Return True if auth is disabled or the bearer token is allowed."""
+    if allowed_tokens is not None and (
+        auth is None or auth.credentials not in allowed_tokens
+    ):
+        return False
+    return True
 
 
-@app.on_event("startup")
-def startup_event():
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Run startup initialization before the app starts serving requests."""
     global vec
     global meta_config
+    global allowed_tokens
+
+    allowed_tokens = get_allowed_tokens()
 
     cuda_env = os.getenv("ENABLE_CUDA")
     cuda_per_process_memory_fraction = 1.0
@@ -113,6 +133,10 @@ def log_info_about_onnx(onnx_runtime: bool):
         model_name,
         trust_remote_code,
     )
+    yield
+
+
+app = FastAPI(lifespan=lifespan)
 
 
 @app.get("/.well-known/live", response_class=Response)
@@ -122,17 +146,32 @@ async def live_and_ready(response: Response):
 
 
 @app.get("/meta")
-def meta():
-    return meta_config.get()
+def meta(
+    response: Response,
+    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+):
+    if is_authorized(auth):
+        return meta_config.get()
+    else:
+        response.status_code = status.HTTP_401_UNAUTHORIZED
+        return {"error": "Unauthorized"}
 
 
 @app.post("/vectors")
 @app.post("/vectors/")
-async def vectorize(item: VectorInput, response: Response):
-    try:
-        vector = await vec.vectorize(item.text, item.config)
-        return {"text": item.text, "vector": vector.tolist(), "dim": len(vector)}
-    except Exception as e:
-        logger.exception("Something went wrong while vectorizing data.")
-        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
-        return {"error": str(e)}
+async def vectorize(
+    item: VectorInput,
+    response: Response,
+    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+):
+    if is_authorized(auth):
+        try:
+            vector = await vec.vectorize(item.text, item.config)
+            return {"text": item.text, "vector": vector.tolist(), "dim": len(vector)}
+        except Exception as e:
+            logger.exception("Something went wrong while vectorizing data.")
+            response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+            return {"error": str(e)}
+    else:
+        response.status_code = status.HTTP_401_UNAUTHORIZED
+        return {"error": "Unauthorized"}
diff --git a/cicd/test.sh b/cicd/test.sh
index 4fd5c4b..adfc8e6 100755
--- a/cicd/test.sh
+++ b/cicd/test.sh
@@ -6,7 +6,17 @@ local_repo=${LOCAL_REPO?Variable LOCAL_REPO is required}
 
 pip3 install -r requirements-test.txt
 
-docker run -d -it -p "8000:8080" "$local_repo"
+echo "Running tests with authorization on"
+
+container_id=$(docker run -d -it -e AUTHENTICATION_ALLOWED_TOKENS='token1,token2' -p "8000:8080" "$local_repo")
+
+python3 smoke_auth_test.py
+
+docker stop "$container_id"
+
+echo "Running tests without authorization"
+
+container_id=$(docker run -d -it -p "8000:8080" "$local_repo")
 
 python3 smoke_test.py
 pytest test_app.py
diff --git a/config.py b/config.py
index f7d8dd6..467f504 100644
--- a/config.py
+++ b/config.py
@@ -1,3 +1,16 @@
 import os
+from typing import List, Optional
 
 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", False)
+
+
+def get_allowed_tokens() -> Optional[List[str]]:
+    """Return the tokens listed in AUTHENTICATION_ALLOWED_TOKENS.
+
+    Returns None when the variable is unset or blank.
+    """
+    raw = os.getenv("AUTHENTICATION_ALLOWED_TOKENS", "").strip()
+    if not raw:
+        return None
+    # Strip each entry so "a, b" works, and drop empties from stray commas.
+    return [token.strip() for token in raw.split(",") if token.strip()]
diff --git a/smoke_auth_test.py b/smoke_auth_test.py
new file mode 100644
index 0000000..2ef03fc
--- /dev/null
+++ b/smoke_auth_test.py
@@ -0,0 +1,93 @@
+import time
+import unittest
+import requests
+
+
+class SmokeTest(unittest.TestCase):
+    def setUp(self):
+        self.url = "http://localhost:8000"
+
+        for i in range(0, 100):
+            try:
+                res = requests.get(self.url + "/.well-known/ready")
+                if res.status_code == 204:
+                    return
+                else:
+                    raise Exception("status code is {}".format(res.status_code))
+            except Exception as e:
+                print("Attempt {}: {}".format(i, e))
+                time.sleep(1)
+
+        raise Exception("did not start up")
+
+    def test_well_known_ready(self):
+        res = requests.get(self.url + "/.well-known/ready")
+
+        self.assertEqual(res.status_code, 204)
+
+    def test_well_known_live(self):
+        res = requests.get(self.url + "/.well-known/live")
+
+        self.assertEqual(res.status_code, 204)
+
+    def test_meta_unauthorized(self):
+        res = requests.get(self.url + "/meta")
+
+        self.assertEqual(res.status_code, 401)
+        self.assertEqual(res.json()["error"], "Unauthorized")
+
+        headers = {"Authorization": "Bearer bad-token"}
+        res = requests.get(self.url + "/meta", headers=headers)
+
+        self.assertEqual(res.status_code, 401)
+        self.assertEqual(res.json()["error"], "Unauthorized")
+
+    def test_meta(self):
+        headers = {"Authorization": "Bearer token1"}
+        res = requests.get(self.url + "/meta", headers=headers)
+
+        self.assertEqual(res.status_code, 200)
+        self.assertIsInstance(res.json(), dict)
+
+    def test_vectorizing_unauthorized(self):
+        req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}
+        res = requests.post(self.url + "/vectors", json=req_body)
+
+        self.assertEqual(res.status_code, 401)
+        self.assertEqual(res.json()["error"], "Unauthorized")
+
+        headers = {"Authorization": "Bearer bad-token"}
+        res = requests.post(self.url + "/vectors", json=req_body, headers=headers)
+
+        self.assertEqual(res.status_code, 401)
+        self.assertEqual(res.json()["error"], "Unauthorized")
+
+    def test_vectorizing(self):
+        def get_req_body(task_type: str = ""):
+            req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}
+            if task_type != "":
+                req_body["config"] = {"task_type": task_type}
+            return req_body
+
+        def try_to_vectorize(url, task_type: str = ""):
+            print(f"url: {url}")
+            req_body = get_req_body(task_type)
+
+            headers = {"Authorization": "Bearer token2"}
+            res = requests.post(url, json=req_body, headers=headers)
+            resBody = res.json()
+
+            self.assertEqual(200, res.status_code)
+
+            # below tests that what we deem a reasonable vector is returned. We are
+            # aware of 384 and 768 dim vectors, which should both fall in that
+            # range
+            self.assertTrue(len(resBody["vector"]) > 100)
+            print(f"vector dimensions are: {len(resBody['vector'])}")
+
+        try_to_vectorize(self.url + "/vectors/")
+        try_to_vectorize(self.url + "/vectors")
+
+
+if __name__ == "__main__":
+    unittest.main()