diff --git a/test/test_api.py b/test/test_api.py index d6f6c2c..6ddd48a 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -1,11 +1,10 @@ -import os import time import unittest +from multiprocessing import Process from typing import Tuple import httpx -import torch -from fastapi.testclient import TestClient +import uvicorn from whisper_api import app @@ -14,25 +13,15 @@ """ -# remove the comment to run the tests against a local server -# os.environ["test_base_url"] = "http://127.0.0.1:3001" - - def do_test() -> Tuple[bool, str]: """ Decide whether to run the tests or not. """ - if not torch.cuda.is_available() and (os.environ.get("test_base_url") is None): - return False, "CUDA is not available and test_base_url is not set" - return True, "" + return False, "This tests are currently disabled" -# if env test_base_url is set, use that as the base url -# export test_base_url=http://127.0.0.1:3001 -if os.environ.get("test_base_url") is not None: - client = httpx.Client(base_url=os.environ["test_base_url"]) -else: - client = TestClient(app) +def run_server(): + uvicorn.run(app, port=10291) class TestAPI(unittest.TestCase): @@ -42,17 +31,30 @@ class TestAPI(unittest.TestCase): do_test, reason = do_test() + proc = Process(target=run_server, args=(), daemon=False) + client = httpx.Client(base_url="http://127.0.0.1:10291") + + @classmethod + def setUpClass(cls): + cls.proc.start() + + @classmethod + def tearDownClass(cls): + cls.proc.kill() + @unittest.skipIf(not do_test, reason) def test_loaded_model(self): """ Test that the API is reachable and the model is loaded within 120 seconds. """ + time.sleep(5) + timeout = 120 start_time = time.time() while time.time() - start_time < timeout: - response = client.get("/api/v1/decoder_status") + response = self.client.get("/api/v1/decoder_status") if response.json().get("is_model_loaded") == True: break print("Waiting for model to load...") @@ -69,18 +71,18 @@ def test_transcribe(self): """ file = open("test/files/En-Open_Source_Software_CD-article.ogg", "rb") files = {"file": file} - response = client.post("/api/v1/transcribe", files=files) + response = self.client.post("/api/v1/transcribe", files=files) file.close() self.assertEqual(response.status_code, 200) print(response.json()) - timeout = 30 + timeout = 60 start_time = time.time() while time.time() - start_time < timeout: - response = client.get(f"/api/v1/status?task_id={response.json().get("task_id")}") + response = self.client.get(f"/api/v1/status?task_id={response.json().get("task_id")}") if (response.json().get("status")) == "finished": break print("Waiting for transcription to complete...")