From 32727462971df168658705b2015fd24661ea3c17 Mon Sep 17 00:00:00 2001 From: Jacob Filik Date: Mon, 24 Jun 2024 16:11:54 +0100 Subject: [PATCH] add more unit test (#13) * add more unit test * missing ruff --- .../src/xas_standards_api/crud.py | 35 +++-- .../src/xas_standards_api/routers/admin.py | 19 ++- .../src/xas_standards_api/routers/open.py | 8 +- xas-standards-api/tests/test_admin_router.py | 69 ++++++++++ xas-standards-api/tests/test_app.py | 102 +-------------- xas-standards-api/tests/test_crud.py | 4 +- xas-standards-api/tests/test_open_router.py | 79 ++++++++++++ .../tests/test_protected_router.py | 58 +++++++++ xas-standards-api/tests/utils.py | 121 ++++++++++++++++++ 9 files changed, 377 insertions(+), 118 deletions(-) create mode 100644 xas-standards-api/tests/test_admin_router.py create mode 100644 xas-standards-api/tests/test_open_router.py create mode 100644 xas-standards-api/tests/test_protected_router.py create mode 100644 xas-standards-api/tests/utils.py diff --git a/xas-standards-api/src/xas_standards_api/crud.py b/xas-standards-api/src/xas_standards_api/crud.py index 27a6f41..a83442f 100644 --- a/xas-standards-api/src/xas_standards_api/crud.py +++ b/xas-standards-api/src/xas_standards_api/crud.py @@ -50,7 +50,6 @@ def read_standards_page( session: Session, element: str | None = None, ) -> CursorPage[XASStandardResponse]: - statement = select(XASStandard).where( XASStandard.review_status == ReviewStatus.approved ) @@ -89,7 +88,10 @@ def get_metadata(session): def get_standard(session, id) -> XASStandard: standard = session.get(XASStandard, id) + if standard: + if standard.review_status != ReviewStatus.approved: + raise HTTPException(status_code=401, detail="Standard not available") return standard else: raise HTTPException(status_code=404, detail=f"No standard with id={id}") @@ -132,7 +134,6 @@ def select_or_create_person(session, identifier): def add_new_standard(session, file1, xs_input: XASStandardInput, additional_files): - tmp_filename = pvc_location + str(uuid.uuid4()) with open(tmp_filename, "wb") as ntf: @@ -166,7 +167,6 @@ def add_new_standard(session, file1, xs_input: XASStandardInput, additional_file def get_filepath(session, id): - standard = session.get(XASStandard, id) if not standard: raise HTTPException(status_code=404, detail=f"No standard with id={id}") @@ -186,7 +186,10 @@ def get_file(session, id): return FileResponse(xdi_location) -def get_file_as_text(session, id): +def get_file_as_text(session, id, user_id): + # only admins can see original file + is_admin_user(session, user_id) + xdi_location = get_filepath(session, id) with open(xdi_location) as fh: file = fh.read() @@ -195,7 +198,6 @@ def get_file_as_text(session, id): def get_norm(energy, group, type): - if type in group: r = group[type] tr = set_xafsGroup(None) @@ -208,7 +210,6 @@ def get_norm(energy, group, type): def get_data(session, id): - xdi_location = get_filepath(session, id) xdi_data = xdi.read_xdi(xdi_location) @@ -237,17 +238,25 @@ def get_standards_admin( session: Session, user_id: str, ): + is_admin_user(session, user_id) + + statement = select(XASStandard).where( + XASStandard.review_status == ReviewStatus.pending + ) + + return paginate(session, statement.order_by(XASStandard.id)) + + +def is_admin_user(session: Session, user_id: str): statement = select(Person).where(Person.identifier == user_id) person = session.exec(statement).first() - if person is None or not person.admin: - raise HTTPException(status_code=401, detail=f"No standard with id={user_id}") + if person is None: + raise HTTPException( + status_code=401, detail=f"No person associated with id {user_id}" + ) if not person.admin: raise HTTPException(status_code=401, detail=f"User {user_id} not admin") - statement = select(XASStandard).where( - XASStandard.review_status == ReviewStatus.pending - ) - - return paginate(session, statement.order_by(XASStandard.id)) + return True diff --git a/xas-standards-api/src/xas_standards_api/routers/admin.py b/xas-standards-api/src/xas_standards_api/routers/admin.py index eb57d11..6ff5db8 100644 --- a/xas-standards-api/src/xas_standards_api/routers/admin.py +++ b/xas-standards-api/src/xas_standards_api/routers/admin.py @@ -1,9 +1,11 @@ +from typing import Optional + from fastapi import APIRouter, Depends from fastapi_pagination.cursor import CursorPage from sqlmodel import Session from ..auth import get_current_user -from ..crud import get_file_as_text, get_standards_admin +from ..crud import get_data, get_file, get_file_as_text, get_standards_admin from ..database import get_session from ..models.response_models import AdminXASStandardResponse @@ -11,9 +13,19 @@ @router.get("/api/admin/data/{id}") -async def read_admin_data(id: int, session: Session = Depends(get_session)): +async def read_admin_data( + id: int, + format: Optional[str] = "", + session: Session = Depends(get_session), + user_id: str = Depends(get_current_user), +): + if format == "download": + return get_file(session, id) - return get_file_as_text(session, id) + if format == "json": + return get_data(session, id) + + return get_file_as_text(session, id, user_id) @router.get("/api/admin/standards") @@ -21,5 +33,4 @@ def read_standards_admin( session: Session = Depends(get_session), user_id: str = Depends(get_current_user), ) -> CursorPage[AdminXASStandardResponse]: - return get_standards_admin(session, user_id) diff --git a/xas-standards-api/src/xas_standards_api/routers/open.py b/xas-standards-api/src/xas_standards_api/routers/open.py index 48747b9..529e445 100644 --- a/xas-standards-api/src/xas_standards_api/routers/open.py +++ b/xas-standards-api/src/xas_standards_api/routers/open.py @@ -1,11 +1,12 @@ from typing import Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from fastapi_pagination.cursor import CursorPage from sqlmodel import Session from ..crud import get_data, get_file, get_metadata, get_standard, read_standards_page from ..database import get_session +from ..models.models import ReviewStatus from ..models.response_models import ( MetadataResponse, XASStandardResponse, @@ -31,7 +32,6 @@ def read_standards( session: Session = Depends(get_session), element: str | None = None, ) -> CursorPage[XASStandardResponse]: - return read_standards_page(session, element) @@ -39,6 +39,10 @@ def read_standards( async def read_data( id: int, format: Optional[str] = "json", session: Session = Depends(get_session) ): + standard = get_standard(session, id) + + if standard.review_status != ReviewStatus.approved: + raise HTTPException(status_code=401, detail="Standard data not available") if format == "xdi": return get_file(session, id) diff --git a/xas-standards-api/tests/test_admin_router.py b/xas-standards-api/tests/test_admin_router.py new file mode 100644 index 0000000..b070424 --- /dev/null +++ b/xas-standards-api/tests/test_admin_router.py @@ -0,0 +1,69 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool + +from utils import build_test_database +from xas_standards_api.app import app +from xas_standards_api.auth import get_current_user +from xas_standards_api.database import get_session +from xas_standards_api.models.response_models import AdminXASStandardResponse + + +def test_admin_read_permissions(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + build_test_database(session) + + def get_session_override(): + return session + + def get_ordinary_user(): + return "user" + + def get_admin_user(): + return "admin" + + client = TestClient(app) + + # first try with ordinary user + app.dependency_overrides[get_session] = get_session_override + app.dependency_overrides[get_current_user] = get_ordinary_user + + response = client.get("/api/admin/standards") + + assert response.status_code == 401 + + # check cant get data + response = client.get("/api/admin/data/2") + assert response.status_code == 401 + + # check cant get data from open endpoint + response = client.get("/api/data/2") + assert response.status_code == 401 + + # now try admin user + app.dependency_overrides.clear() + app.dependency_overrides[get_session] = get_session_override + app.dependency_overrides[get_current_user] = get_admin_user + + response = client.get("/api/admin/standards") + r = response.json() + + # check response is paginated, containing 1 item + assert "items" in r + assert len(r["items"]) == 1 + + # check its correct response and contains the submitter identifier + axassr = AdminXASStandardResponse.model_validate(r["items"][0]) + assert axassr.submitter.identifier == "user" + + # check can get data + response = client.get("/api/admin/data/2") + + assert response.text.startswith("# XDI") diff --git a/xas-standards-api/tests/test_app.py b/xas-standards-api/tests/test_app.py index 8bbab74..6347dca 100644 --- a/xas-standards-api/tests/test_app.py +++ b/xas-standards-api/tests/test_app.py @@ -1,102 +1,10 @@ from fastapi.testclient import TestClient -from sqlmodel import Session, SQLModel, create_engine -from sqlmodel.pool import StaticPool from xas_standards_api.app import app -from xas_standards_api.auth import get_current_user -from xas_standards_api.database import get_session -from xas_standards_api.models.models import Beamline, Edge, Element, Facility, Person -from xas_standards_api.models.response_models import MetadataResponse -client = TestClient(app) - -def test_read_item(): - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - - with Session(engine) as session: - - session.add(Element(name="Hydrogen", z=1, symbol="H")) - session.add(Edge(name="K", id=1, level="sp")) - session.add( - Facility( - id=1, - name="synchrotron", - notes="a place", - fullname="a synchrotron", - city="somewhere", - region="someplace", - laboratory="a lab", - country="somecountry", - ) - ) - - session.add( - Beamline( - facility_id=1, - id=1, - name="my beamline", - notes="a beamline", - xray_source="BM", - ) - ) - session.commit() - - def get_session_override(): - return session - - app.dependency_overrides[get_session] = get_session_override - - client = TestClient(app) - - response = client.get("/api/metadata/") - app.dependency_overrides.clear() - - print(response) - - mr = MetadataResponse.model_validate(response.json()) - - assert response.status_code == 200 - assert mr.elements[0].symbol == "H" - assert mr.edges[0].name == "K" - assert mr.beamlines[0].name == "my beamline" - assert mr.beamlines[0].facility.name == "synchrotron" - - -def test_read_person(): - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - SQLModel.metadata.create_all(engine) - - with Session(engine) as session: - - session.add(Person(id=1, identifier="abc123", admin=False)) - - session.commit() - - def get_session_override(): - return session - - def get_current_user_override(): - return "abc123" - - app.dependency_overrides[get_session] = get_session_override - app.dependency_overrides[get_current_user] = get_current_user_override - - client = TestClient(app) - - response = client.get("/api/user/") - app.dependency_overrides.clear() - - r = response.json() - - assert r["user"] == "abc123" - assert not r["admin"] +def test_login_redirect(): + client = TestClient(app) + response = client.get("/login") + # expect 404 since root is not defined in test + assert response.status_code == 404 diff --git a/xas-standards-api/tests/test_crud.py b/xas-standards-api/tests/test_crud.py index f30b088..6535830 100644 --- a/xas-standards-api/tests/test_crud.py +++ b/xas-standards-api/tests/test_crud.py @@ -5,15 +5,15 @@ from sqlmodel import Session from xas_standards_api import crud -from xas_standards_api.models.models import XASStandard +from xas_standards_api.models.models import ReviewStatus, XASStandard def test_get_standard(): - mock_session = create_autospec(Session, instance=True) test_id = 0 result = XASStandard() + result.review_status = ReviewStatus.approved # Session returns None, i.e. no standard for id mock_session.get = Mock(return_value=None) diff --git a/xas-standards-api/tests/test_open_router.py b/xas-standards-api/tests/test_open_router.py new file mode 100644 index 0000000..315043d --- /dev/null +++ b/xas-standards-api/tests/test_open_router.py @@ -0,0 +1,79 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool + +from utils import build_test_database +from xas_standards_api.app import app +from xas_standards_api.database import get_session +from xas_standards_api.models.response_models import ( + MetadataResponse, + XASStandardResponse, +) + + +def test_read_metadata(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + build_test_database(session) + + def get_session_override(): + return session + + client = TestClient(app) + app.dependency_overrides[get_session] = get_session_override + + response = client.get("/api/metadata/") + + print(response) + + mr = MetadataResponse.model_validate(response.json()) + + assert response.status_code == 200 + assert mr.elements[0].symbol == "H" + assert mr.edges[0].name == "K" + assert mr.beamlines[0].name == "my beamline" + assert mr.beamlines[0].facility.name == "synchrotron" + + response = client.get("/api/standards/") + + rjson = response.json() + + assert "items" in rjson + assert len(rjson["items"]) == 1 + + assert "submitter" not in rjson["items"][0] + + xassr = XASStandardResponse.model_validate(rjson["items"][0]) + + assert xassr.id == 1 + + # check can get data from open endpoint + response = client.get("/api/data/1") + assert response.status_code == 200 + + rjson = response.json() + + assert "mutrans" in rjson + + # check cant get unreviewed data from open endpoint + response = client.get("/api/data/2") + assert response.status_code == 401 + + # check cant get id that doesnt exist + response = client.get("/api/data/3") + assert response.status_code == 404 + + response = client.get("/api/standards/1") + assert response.status_code == 200 + + response = client.get("/api/standards/2") + assert response.status_code == 401 + + response = client.get("/api/standards/3") + assert response.status_code == 404 diff --git a/xas-standards-api/tests/test_protected_router.py b/xas-standards-api/tests/test_protected_router.py new file mode 100644 index 0000000..8441c09 --- /dev/null +++ b/xas-standards-api/tests/test_protected_router.py @@ -0,0 +1,58 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session, SQLModel, create_engine +from sqlmodel.pool import StaticPool + +from utils import build_test_database +from xas_standards_api.app import app +from xas_standards_api.auth import get_current_user +from xas_standards_api.database import get_session + + +def test_read_person(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + build_test_database(session) + + def get_session_override(): + return session + + def get_ordinary_user(): + return "user" + + def get_admin_user(): + return "admin" + + session.commit() + + # check non admin user is non admin user + app.dependency_overrides[get_session] = get_session_override + app.dependency_overrides[get_current_user] = get_ordinary_user + + client = TestClient(app) + + response = client.get("/api/user/") + + r = response.json() + assert r["user"] == "user" + assert not r["admin"] + + # check admin user is admin user + app.dependency_overrides.clear() + app.dependency_overrides[get_session] = get_session_override + app.dependency_overrides[get_current_user] = get_admin_user + + response = client.get("/api/user/") + + r = response.json() + + assert r["user"] == "admin" + assert r["admin"] + + # TODO check post of standard + # TODO check patch of standard diff --git a/xas-standards-api/tests/utils.py b/xas-standards-api/tests/utils.py new file mode 100644 index 0000000..09e35af --- /dev/null +++ b/xas-standards-api/tests/utils.py @@ -0,0 +1,121 @@ +from datetime import datetime + +from sqlmodel import Session + +from xas_standards_api.models.models import ( + Beamline, + Edge, + Element, + Facility, + LicenceType, + Person, + ReviewStatus, + XASStandard, + XASStandardData, +) + + +def build_test_database(session: Session): + session.add(Person(id=1, identifier="admin", admin=True)) + session.add(Person(id=2, identifier="user", admin=False)) + + session.add(Element(name="Hydrogen", z=1, symbol="H")) + session.add(Element(name="Helium", z=2, symbol="He")) + session.add(Element(name="Lithium", z=3, symbol="Li")) + session.add(Element(name="Beryllium", z=4, symbol="Be")) + + session.add(Edge(name="K", id=1, level="1s")) + session.add(Edge(name="L3", id=2, level="2p3/2")) + + session.add( + Facility( + id=1, + name="synchrotron", + notes="a place", + fullname="a synchrotron", + city="somewhere", + region="someplace", + laboratory="a lab", + country="somecountry", + ) + ) + + session.add( + Beamline( + facility_id=1, + id=1, + name="my beamline", + notes="a beamline", + xray_source="BM", + ) + ) + + xas_data1 = XASStandardData( + emission=False, + transmission=True, + fluorescence=False, + reference=False, + location="./test.xdi", + original_filename="standard.xdi", + ) + xas_data2 = XASStandardData( + emission=False, + transmission=True, + fluorescence=False, + reference=False, + location="./test.xdi", + original_filename="standard.xdi", + ) + + session.add(xas_data1) + session.add(xas_data2) + + session.commit() + session.refresh(xas_data1) + session.refresh(xas_data2) + + session.add( + XASStandard( + submitter_id=2, + submission_date=datetime.min, + collection_date=datetime.min, + doi="doi", + citation="citation", + element_z=1, + edge_id=1, + sample_name="sample", + sample_prep="pellet", + sample_comp="H", + beamline_id=1, + licence=LicenceType.cc_0, + id=1, + data_id=xas_data1.id, + reviewer_id=1, + reviewer_comments="good", + review_status=ReviewStatus.approved, + ) + ) + + session.add( + XASStandard( + submitter_id=2, + submission_date=datetime.min, + collection_date=datetime.min, + doi="doi", + citation="citation", + element_z=1, + edge_id=1, + sample_name="sample", + sample_prep="pellet", + sample_comp="He", + beamline_id=1, + licence=LicenceType.cc_0, + id=2, + data_id=xas_data1.id, + reviewer_id=None, + reviewer_comments=None, + review_status=ReviewStatus.pending, + ) + ) + + session.commit()