Skip to content

Commit

Permalink
test server in background proc
Browse files Browse the repository at this point in the history
  • Loading branch information
MayNiklas committed Jan 2, 2025
1 parent d1ed3a5 commit b10cc32
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions test/test_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
import time
import unittest
from multiprocessing import Process
from typing import Tuple

import httpx
import torch
from fastapi.testclient import TestClient
import uvicorn

from whisper_api import app

Expand All @@ -14,25 +13,15 @@
"""


# remove the comment to run the tests against a local server
# os.environ["test_base_url"] = "http://127.0.0.1:3001"


def do_test() -> Tuple[bool, str]:
    """Decide whether the API tests should run.

    Returns:
        A ``(run, reason)`` tuple: ``run`` is True when a CUDA device is
        available or a remote server is configured via the
        ``test_base_url`` environment variable; otherwise False together
        with a human-readable reason used as the skip message.
    """
    # Skip when there is neither a local GPU to load the model on nor a
    # remote, already-running server to test against.
    if not torch.cuda.is_available() and (os.environ.get("test_base_url") is None):
        return False, "CUDA is not available and test_base_url is not set"
    return True, ""


# if env test_base_url is set, use that as the base url
# export test_base_url=http://127.0.0.1:3001
if os.environ.get("test_base_url") is not None:
client = httpx.Client(base_url=os.environ["test_base_url"])
else:
client = TestClient(app)
def run_server():
    """Run the FastAPI app under uvicorn on localhost port 10291 (blocking).

    Intended as a ``multiprocessing.Process`` target so the test suite can
    exercise a real HTTP server running in a background process.
    """
    uvicorn.run(app, port=10291)


class TestAPI(unittest.TestCase):
Expand All @@ -42,17 +31,30 @@ class TestAPI(unittest.TestCase):

do_test, reason = do_test()

proc = Process(target=run_server, args=(), daemon=False)
client = httpx.Client(base_url="http://127.0.0.1:10291")

    @classmethod
    def setUpClass(cls):
        """Start the background uvicorn server process before any test runs."""
        cls.proc.start()

@classmethod
def tearDownClass(cls):
cls.proc.kill()

@unittest.skipIf(not do_test, reason)
def test_loaded_model(self):
"""
Test that the API is reachable and the model is loaded within 120 seconds.
"""

time.sleep(5)

timeout = 120
start_time = time.time()

while time.time() - start_time < timeout:
response = client.get("/api/v1/decoder_status")
response = self.client.get("/api/v1/decoder_status")
if response.json().get("is_model_loaded") == True:
break
print("Waiting for model to load...")
Expand All @@ -69,18 +71,18 @@ def test_transcribe(self):
"""
file = open("test/files/En-Open_Source_Software_CD-article.ogg", "rb")
files = {"file": file}
response = client.post("/api/v1/transcribe", files=files)
response = self.client.post("/api/v1/transcribe", files=files)
file.close()

self.assertEqual(response.status_code, 200)

print(response.json())

timeout = 30
timeout = 60
start_time = time.time()

while time.time() - start_time < timeout:
response = client.get(f"/api/v1/status?task_id={response.json().get("task_id")}")
response = self.client.get(f"/api/v1/status?task_id={response.json().get("task_id")}")
if (response.json().get("status")) == "finished":
break
print("Waiting for transcription to complete...")
Expand Down

0 comments on commit b10cc32

Please sign in to comment.