Skip to content

Commit

Permalink
add more unit test (#13)
Browse files Browse the repository at this point in the history
* add more unit test

* missing ruff
  • Loading branch information
jacobfilik authored Jun 24, 2024
1 parent c0a4194 commit 3272746
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 118 deletions.
35 changes: 22 additions & 13 deletions xas-standards-api/src/xas_standards_api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def read_standards_page(
session: Session,
element: str | None = None,
) -> CursorPage[XASStandardResponse]:

statement = select(XASStandard).where(
XASStandard.review_status == ReviewStatus.approved
)
Expand Down Expand Up @@ -89,7 +88,10 @@ def get_metadata(session):

def get_standard(session, id) -> XASStandard:
standard = session.get(XASStandard, id)

if standard:
if standard.review_status != ReviewStatus.approved:
raise HTTPException(status_code=401, detail="Standard not available")
return standard
else:
raise HTTPException(status_code=404, detail=f"No standard with id={id}")
Expand Down Expand Up @@ -132,7 +134,6 @@ def select_or_create_person(session, identifier):


def add_new_standard(session, file1, xs_input: XASStandardInput, additional_files):

tmp_filename = pvc_location + str(uuid.uuid4())

with open(tmp_filename, "wb") as ntf:
Expand Down Expand Up @@ -166,7 +167,6 @@ def add_new_standard(session, file1, xs_input: XASStandardInput, additional_file


def get_filepath(session, id):

standard = session.get(XASStandard, id)
if not standard:
raise HTTPException(status_code=404, detail=f"No standard with id={id}")
Expand All @@ -186,7 +186,10 @@ def get_file(session, id):
return FileResponse(xdi_location)


def get_file_as_text(session, id):
def get_file_as_text(session, id, user_id):
# only admins can see original file
is_admin_user(session, user_id)

xdi_location = get_filepath(session, id)
with open(xdi_location) as fh:
file = fh.read()
Expand All @@ -195,7 +198,6 @@ def get_file_as_text(session, id):


def get_norm(energy, group, type):

if type in group:
r = group[type]
tr = set_xafsGroup(None)
Expand All @@ -208,7 +210,6 @@ def get_norm(energy, group, type):


def get_data(session, id):

xdi_location = get_filepath(session, id)

xdi_data = xdi.read_xdi(xdi_location)
Expand Down Expand Up @@ -237,17 +238,25 @@ def get_standards_admin(
session: Session,
user_id: str,
):
is_admin_user(session, user_id)

statement = select(XASStandard).where(
XASStandard.review_status == ReviewStatus.pending
)

return paginate(session, statement.order_by(XASStandard.id))


def is_admin_user(session: Session, user_id: str):
statement = select(Person).where(Person.identifier == user_id)
person = session.exec(statement).first()

if person is None or not person.admin:
raise HTTPException(status_code=401, detail=f"No standard with id={user_id}")
if person is None:
raise HTTPException(
status_code=401, detail=f"No person associated with id {user_id}"
)

if not person.admin:
raise HTTPException(status_code=401, detail=f"User {user_id} not admin")

statement = select(XASStandard).where(
XASStandard.review_status == ReviewStatus.pending
)

return paginate(session, statement.order_by(XASStandard.id))
return True
19 changes: 15 additions & 4 deletions xas-standards-api/src/xas_standards_api/routers/admin.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,36 @@
from typing import Optional

from fastapi import APIRouter, Depends
from fastapi_pagination.cursor import CursorPage
from sqlmodel import Session

from ..auth import get_current_user
from ..crud import get_file_as_text, get_standards_admin
from ..crud import get_data, get_file, get_file_as_text, get_standards_admin
from ..database import get_session
from ..models.response_models import AdminXASStandardResponse

router = APIRouter()


@router.get("/api/admin/data/{id}")
async def read_admin_data(id: int, session: Session = Depends(get_session)):
async def read_admin_data(
id: int,
format: Optional[str] = "",
session: Session = Depends(get_session),
user_id: str = Depends(get_current_user),
):
if format == "download":
return get_file(session, id)

return get_file_as_text(session, id)
if format == "json":
return get_data(session, id)

return get_file_as_text(session, id, user_id)


@router.get("/api/admin/standards")
def read_standards_admin(
session: Session = Depends(get_session),
user_id: str = Depends(get_current_user),
) -> CursorPage[AdminXASStandardResponse]:

return get_standards_admin(session, user_id)
8 changes: 6 additions & 2 deletions xas-standards-api/src/xas_standards_api/routers/open.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination.cursor import CursorPage
from sqlmodel import Session

from ..crud import get_data, get_file, get_metadata, get_standard, read_standards_page
from ..database import get_session
from ..models.models import ReviewStatus
from ..models.response_models import (
MetadataResponse,
XASStandardResponse,
Expand All @@ -31,14 +32,17 @@ def read_standards(
session: Session = Depends(get_session),
element: str | None = None,
) -> CursorPage[XASStandardResponse]:

return read_standards_page(session, element)


@router.get("/api/data/{id}")
async def read_data(
id: int, format: Optional[str] = "json", session: Session = Depends(get_session)
):
standard = get_standard(session, id)

if standard.review_status != ReviewStatus.approved:
raise HTTPException(status_code=401, detail="Standard data not available")

if format == "xdi":
return get_file(session, id)
Expand Down
69 changes: 69 additions & 0 deletions xas-standards-api/tests/test_admin_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

from utils import build_test_database
from xas_standards_api.app import app
from xas_standards_api.auth import get_current_user
from xas_standards_api.database import get_session
from xas_standards_api.models.response_models import AdminXASStandardResponse


def test_admin_read_permissions():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
build_test_database(session)

def get_session_override():
return session

def get_ordinary_user():
return "user"

def get_admin_user():
return "admin"

client = TestClient(app)

# first try with ordinary user
app.dependency_overrides[get_session] = get_session_override
app.dependency_overrides[get_current_user] = get_ordinary_user

response = client.get("/api/admin/standards")

assert response.status_code == 401

# check cant get data
response = client.get("/api/admin/data/2")
assert response.status_code == 401

# check cant get data from open endpoint
response = client.get("/api/data/2")
assert response.status_code == 401

# now try admin user
app.dependency_overrides.clear()
app.dependency_overrides[get_session] = get_session_override
app.dependency_overrides[get_current_user] = get_admin_user

response = client.get("/api/admin/standards")
r = response.json()

# check response is paginated, containing 1 item
assert "items" in r
assert len(r["items"]) == 1

# check its correct response and contains the submitter identifier
axassr = AdminXASStandardResponse.model_validate(r["items"][0])
assert axassr.submitter.identifier == "user"

# check can get data
response = client.get("/api/admin/data/2")

assert response.text.startswith("# XDI")
102 changes: 5 additions & 97 deletions xas-standards-api/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,10 @@
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

from xas_standards_api.app import app
from xas_standards_api.auth import get_current_user
from xas_standards_api.database import get_session
from xas_standards_api.models.models import Beamline, Edge, Element, Facility, Person
from xas_standards_api.models.response_models import MetadataResponse

client = TestClient(app)


def test_read_item():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)

with Session(engine) as session:

session.add(Element(name="Hydrogen", z=1, symbol="H"))
session.add(Edge(name="K", id=1, level="sp"))
session.add(
Facility(
id=1,
name="synchrotron",
notes="a place",
fullname="a synchrotron",
city="somewhere",
region="someplace",
laboratory="a lab",
country="somecountry",
)
)

session.add(
Beamline(
facility_id=1,
id=1,
name="my beamline",
notes="a beamline",
xray_source="BM",
)
)
session.commit()

def get_session_override():
return session

app.dependency_overrides[get_session] = get_session_override

client = TestClient(app)

response = client.get("/api/metadata/")
app.dependency_overrides.clear()

print(response)

mr = MetadataResponse.model_validate(response.json())

assert response.status_code == 200
assert mr.elements[0].symbol == "H"
assert mr.edges[0].name == "K"
assert mr.beamlines[0].name == "my beamline"
assert mr.beamlines[0].facility.name == "synchrotron"


def test_read_person():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)

with Session(engine) as session:

session.add(Person(id=1, identifier="abc123", admin=False))

session.commit()

def get_session_override():
return session

def get_current_user_override():
return "abc123"

app.dependency_overrides[get_session] = get_session_override
app.dependency_overrides[get_current_user] = get_current_user_override

client = TestClient(app)

response = client.get("/api/user/")
app.dependency_overrides.clear()

r = response.json()

assert r["user"] == "abc123"
assert not r["admin"]
def test_login_redirect():
client = TestClient(app)
response = client.get("/login")
# expect 404 since root is not defined in test
assert response.status_code == 404
4 changes: 2 additions & 2 deletions xas-standards-api/tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from sqlmodel import Session

from xas_standards_api import crud
from xas_standards_api.models.models import XASStandard
from xas_standards_api.models.models import ReviewStatus, XASStandard


def test_get_standard():

mock_session = create_autospec(Session, instance=True)

test_id = 0
result = XASStandard()
result.review_status = ReviewStatus.approved

# Session returns None, i.e. no standard for id
mock_session.get = Mock(return_value=None)
Expand Down
Loading

0 comments on commit 3272746

Please sign in to comment.