diff --git a/python/main.py b/python/main.py index 7347d76..a8f1f4e 100644 --- a/python/main.py +++ b/python/main.py @@ -22,11 +22,13 @@ import numpy as np import pandas as pd import pydicom +import pytest import torch -from fastapi import Body, FastAPI, Request, UploadFile +from fastapi import Body, FastAPI, Request, UploadFile, status from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from fastapi.testclient import TestClient from PIL import Image from pydantic import BaseModel from pydicom.errors import InvalidDicomError @@ -159,6 +161,22 @@ def dcm2dictmetadata(ds: pydicom.dataset.Dataset) -> dict[str, dict[str, str]]: templates = Jinja2Templates(directory="templates") app.mount(path="/static", app=StaticFiles(directory="static"), name="static") +client = TestClient(app) + + +def app_url() -> str: + return "http://0.0.0.0:8000" + + +def test_upload_files() -> None: + with Path("./prm/sample.dcm").open("rb") as file: + files = {"files": ("./prm/sample.dcm", file, "application/dicom")} + response = client.post(app_url() + "/upload_files", files=files) + if response.status_code != status.HTTP_200_OK: + raise AssertionError + response.json() + UploadFilesResponse.model_validate(response.json()) + @app.get("/", response_class=HTMLResponse) async def get_root(request: Request) -> HTMLResponse: @@ -1118,14 +1136,14 @@ def ndarray_size(arr: NDArray[Any]) -> int: if __name__ == "__main__": + tmp_directories = [ + Path("tmp/session-data/raw"), + Path("tmp/session-data/clean"), + Path("tmp/session-data/embed"), + ] + for directory in tmp_directories: + directory.mkdir(parents=True, exist_ok=True) if os.getenv("STAGING"): - tmp_directories = [ - Path("tmp/session-data/raw"), - Path("tmp/session-data/clean"), - Path("tmp/session-data/embed"), - ] - for directory in tmp_directories: - directory.mkdir(parents=True, exist_ok=True) if not Path("tmp/fullchain.pem").exists(): subprocess.run( [ # noqa: S603 @@ -1149,3 +1167,7 @@ def ndarray_size(arr: NDArray[Any]) -> int: ssl_certfile="tmp/fullchain.pem", ssl_keyfile="tmp/privkey.pem", ) + else: + results = pytest.main(["-rA", "-o", "cache_dir=tmp", __file__]) + if results.value != 0: # type: ignore[attr-defined] + sys.exit(results) diff --git a/python/prm/sample.dcm b/python/prm/sample.dcm new file mode 100755 index 0000000..2d3aca2 Binary files /dev/null and b/python/prm/sample.dcm differ diff --git a/python/pyproject.toml b/python/pyproject.toml index d08278e..e52338d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -5,6 +5,7 @@ dependencies = [ "pydicom==2.4.4", "pylibjpeg-libjpeg==2.0.2", "pylibjpeg==2.0.0", + "pytest==8.0.2", "python-multipart==0.0.9", "segment-anything@git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588", "tensorflow==2.15.0.post1",