diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..26f13f99 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + db: marks tests as database tests. requires a database container and may be slow. diff --git a/src/mainframe/custom_exceptions.py b/src/mainframe/custom_exceptions.py index e2ab09fe..bd516eff 100644 --- a/src/mainframe/custom_exceptions.py +++ b/src/mainframe/custom_exceptions.py @@ -1,4 +1,17 @@ from fastapi import HTTPException, status +from dataclasses import dataclass + + +@dataclass +class PackageNotFound(Exception): + name: str + version: str + + +@dataclass +class PackageAlreadyReported(Exception): + name: str + reported_version: str class BadCredentialsException(HTTPException): diff --git a/src/mainframe/database.py b/src/mainframe/database.py index 1198204b..4935e736 100644 --- a/src/mainframe/database.py +++ b/src/mainframe/database.py @@ -1,9 +1,15 @@ -from typing import Generator +from collections.abc import Sequence +import datetime as dt +from functools import cache +from typing import Generator, Optional -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy import create_engine, select +from sqlalchemy import orm +from sqlalchemy.orm import Session, joinedload, sessionmaker from mainframe.constants import mainframe_settings +from mainframe.models.orm import Scan +from typing import Protocol # pool_size and max_overflow are set to their default values. There is never # enough load to justify increasing them. @@ -21,3 +27,80 @@ def get_db() -> Generator[Session, None, None]: yield session finally: session.close() + + +class StorageProtocol(Protocol): + def lookup_packages( + self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None + ) -> Sequence[Scan]: + """ + Lookup information on scanned packages based on name, version, or time + scanned. If multiple packages are returned, they are ordered with the most + recently queued package first. + + Args: + since: A int representing a Unix timestamp representing when to begin the search from. + name: The name of the package. + version: The version of the package. + session: DB session. + + Exceptions: + ValueError: Invalid parameter combination was passed. See below. + + Returns: + Sequence of `Scan`s, representing the results of the query + + Only certain combinations of parameters are allowed. A query is valid if any of the following combinations are used: + - `name` and `version`: Return the package with name `name` and version `version`, if it exists. + - `name` and `since`: Find all packages with name `name` since `since`. + - `since`: Find all packages since `since`. + - `name`: Find all packages with name `name`. + All other combinations are disallowed. + + In more formal terms, a query is valid + iff `((name and not since) or (not version and since))` + where a given variable name means that query parameter was passed. Equivalently, a request is invalid + iff `(not (name or since) or (version and since))` + """ + ... + + def mark_reported(self, *, scan: Scan, subject: str) -> None: + """Mark the given `Scan` record as reported by `subject`.""" + ... + + +class DatabaseStorage(StorageProtocol): + def __init__(self, sessionmaker: orm.sessionmaker[Session]): + self.sessionmaker = sessionmaker + + def get_session(self) -> Session: + return self.sessionmaker() + + def lookup_packages( + self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None + ) -> Sequence[Scan]: + query = ( + select(Scan).order_by(Scan.queued_at.desc()).options(joinedload(Scan.rules), joinedload(Scan.download_urls)) + ) + + if name: + query = query.where(Scan.name == name) + if version: + query = query.where(Scan.version == version) + if since: + query = query.where(Scan.finished_at >= since) + + session = self.get_session() + with session, session.begin(): + return session.scalars(query).unique().all() + + def mark_reported(self, *, scan: Scan, subject: str) -> None: + session = self.get_session() + with session, session.begin(): + scan.reported_by = subject + scan.reported_at = dt.datetime.now() + + +@cache +def get_storage() -> DatabaseStorage: + return DatabaseStorage(sessionmaker) diff --git a/src/mainframe/endpoints/report.py b/src/mainframe/endpoints/report.py index 1a0e2e53..40eee632 100644 --- a/src/mainframe/endpoints/report.py +++ b/src/mainframe/endpoints/report.py @@ -1,15 +1,14 @@ -import datetime as dt +from collections.abc import Sequence from typing import Annotated, Optional import httpx import structlog from fastapi import APIRouter, Depends, HTTPException from fastapi.encoders import jsonable_encoder -from sqlalchemy import select -from sqlalchemy.orm import Session, joinedload from mainframe.constants import mainframe_settings -from mainframe.database import get_db +from mainframe.custom_exceptions import PackageNotFound, PackageAlreadyReported +from mainframe.database import StorageProtocol, get_storage from mainframe.dependencies import get_httpx_client, validate_token from mainframe.json_web_token import AuthenticationData from mainframe.models.orm import Scan @@ -28,59 +27,34 @@ router = APIRouter(tags=["report"]) -def _lookup_package(name: str, version: str, session: Session) -> Scan: +def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: """ Checks if the package is valid according to our database. + A package is considered valid if there exists a scan with the given name + and version, and that no other versions have been reported. + + Arguments: + name: The name of the package to validate + version: The version of the package to validate + scans: The sequence of Scan records in the database where name=name Returns: - True if the package exists in the database. + `Scan`: The validated `Scan` object Raises: - HTTPException: 404 Not Found if the name was not found in the database, - or the specified name and version was not found in the database. 409 - Conflict if another version of the same package has already been - reported. + PackageNotFound: The given name and version combination was not found + PackageAlreadyReported: The package was already reported """ - log = logger.bind(package={"name": name, "version": version}) - - query = select(Scan).where(Scan.name == name).options(joinedload(Scan.rules)) - with session.begin(): - scans = session.scalars(query).unique().all() - - if not scans: - error = HTTPException(404, detail=f"No records for package `{name}` were found in the database") - log.error( - f"No records for package {name} found in database", error_message=error.detail, tag="package_not_found_db" - ) - raise error - for scan in scans: if scan.reported_at is not None: - error = HTTPException( - 409, - detail=( - f"Only one version of a package may be reported at a time. " - f"(`{scan.name}@{scan.version}` was already reported)" - ), - ) - log.error( - "Only one version of a package allowed to be reported at a time", - error_message=error.detail, - tag="multiple_versions_prohibited", - ) - raise error - - with session.begin(): - scan = session.scalar(query.where(Scan.version == version)) - if scan is None: - error = HTTPException( - 404, detail=f"Package `{name}` has records in the database, but none with version `{version}`" - ) - log.error(f"No version {version} for package {name} in database", tag="invalid_version") - raise error + raise PackageAlreadyReported(name=scan.name, reported_version=scan.version) + + for scan in scans: + if (scan.name, scan.version) == (name, version): + return scan - return scan + raise PackageNotFound(name=name, version=version) def _validate_inspector_url(name: str, version: str, body_url: Optional[str], scan_url: Optional[str]) -> str: @@ -127,7 +101,7 @@ def _validate_pypi(name: str, version: str, http_client: httpx.Client): ) def report_package( body: ReportPackageBody, - session: Annotated[Session, Depends(get_db)], + database: Annotated[StorageProtocol, Depends(get_storage)], auth: Annotated[AuthenticationData, Depends(validate_token)], httpx_client: Annotated[httpx.Client, Depends(get_httpx_client)], ): @@ -152,7 +126,24 @@ def report_package( log = logger.bind(package={"name": name, "version": version}) # Check our database first to avoid unnecessarily using PyPI API. - scan = _lookup_package(name, version, session) + try: + scans = database.lookup_packages(name) + scan = validate_package(name, version, scans) + except PackageNotFound as e: + detail = f"No records for package `{e.name} v{e.version}` were found in the database" + error = HTTPException(404, detail=detail) + log.error(detail, error_message=detail, tag="package_not_found_db") + + raise error + except PackageAlreadyReported as e: + detail = ( + f"Only one version of a package may be reported at a time " + f"(`{e.name}@{e.reported_version}` was already reported)" + ) + error = HTTPException(409, detail=detail) + log.error(detail, error_message=error.detail, tag="multiple_versions_prohibited") + + raise error inspector_url = _validate_inspector_url(name, version, body.inspector_url, scan.inspector_url) # If execution reaches here, we must have found a matching scan in our @@ -170,11 +161,7 @@ def report_package( httpx_client.post(f"{mainframe_settings.reporter_url}/report/{name}", json=jsonable_encoder(report)) - with session.begin(): - scan.reported_by = auth.subject - scan.reported_at = dt.datetime.now(dt.timezone.utc) - - session.close() + database.mark_reported(scan=scan, subject=auth.subject) log.info( "Sent report", diff --git a/tests/conftest.py b/tests/conftest.py index ea14071a..d2abab88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ +from collections.abc import Sequence, Generator import logging from copy import deepcopy from datetime import datetime, timedelta -from typing import Generator +from typing import Optional from unittest.mock import MagicMock import httpx @@ -12,6 +13,7 @@ from sqlalchemy import Engine, create_engine, text from sqlalchemy.orm import Session, sessionmaker +from mainframe.database import DatabaseStorage, StorageProtocol from mainframe.json_web_token import AuthenticationData from mainframe.models.orm import Base, Scan from mainframe.rules import Rules @@ -22,6 +24,39 @@ logger = logging.getLogger(__file__) +class MockDatabase(StorageProtocol): + def __init__(self) -> None: + self.db: list[Scan] = [] + + def add(self, scan: Scan) -> None: + self.db.append(scan) + + def lookup_packages( + self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[datetime] = None + ) -> Sequence[Scan]: + v: list[Scan] = [] + for scan in self.db: + if ( + (scan.name == name) + or (scan.version == version) + or (scan.queued_at and since and scan.queued_at >= since) + ): + v.append(scan) + + return v + + def mark_reported(self, *, scan: Scan, subject: str) -> None: + for s in self.db: + if s.scan_id == scan.scan_id: + scan.reported_by = subject + scan.reported_at = datetime.now() + + +@pytest.fixture +def mock_database() -> MockDatabase: + return MockDatabase() + + @pytest.fixture(scope="session") def sm(engine: Engine) -> sessionmaker[Session]: return sessionmaker(bind=engine, expire_on_commit=False, autobegin=False) @@ -50,12 +85,26 @@ def engine(superuser_engine: Engine) -> Engine: return create_engine("postgresql+psycopg2://dragonfly:postgres@db:5432/dragonfly", pool_size=5, max_overflow=10) +@pytest.fixture +def storage( + superuser_engine: Engine, test_data: list[Scan], sm: sessionmaker[Session] +) -> Generator[DatabaseStorage, None, None]: + Base.metadata.drop_all(superuser_engine) + Base.metadata.create_all(superuser_engine) + with sm() as s, s.begin(): + s.add_all(deepcopy(test_data)) + + yield DatabaseStorage(sm) + + Base.metadata.drop_all(superuser_engine) + + @pytest.fixture(params=data, scope="session") def test_data(request: pytest.FixtureRequest) -> list[Scan]: return request.param -@pytest.fixture(autouse=True) +@pytest.fixture def db_session( superuser_engine: Engine, test_data: list[Scan], sm: sessionmaker[Session] ) -> Generator[Session, None, None]: diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 00000000..38f5408d --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,241 @@ +from copy import deepcopy +from datetime import datetime, timedelta +from typing import Optional +import pytest +from sqlalchemy import select +from mainframe.database import DatabaseStorage +from mainframe.models.orm import Scan, Status + + +@pytest.mark.db +def test_mark_reported(storage: DatabaseStorage): + scan = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + ) + + session = storage.get_session() + with session.begin(): + session.add(scan) + + storage.mark_reported(scan=scan, subject="remmy") + + query = select(Scan).where(Scan.name == "package1").where(Scan.version == "1.0.0") + actual = session.scalar(query) + assert actual is not None + assert actual.reported_by == "remmy" + assert actual.reported_at is not None + + +@pytest.mark.db +@pytest.mark.parametrize( + "scans,spec,expected", + [ + ( + [Scan(name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now())], + ("package1", None, None), + [("package1", "1.0.0")], + ), + ( + [Scan(name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now())], + ("package1", "1.0.0", None), + [("package1", "1.0.0")], + ), + ( + [ + Scan( + name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + Scan( + name="package1", version="1.0.1", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + ], + ("package1", None, None), + [("package1", "1.0.0"), ("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + Scan( + name="package1", version="1.0.1", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + ], + ("package1", "1.0.1", None), + [("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", None, 0), + [("package1", "1.0.0"), ("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", None, datetime.now() - timedelta(seconds=4)), + [("package1", "1.0.1")], + ), + # we must use a static time for this test here because it can be flaky otherwise + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime(2024, 10, 4, 2, 4) - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime(2024, 10, 4, 2, 4) - timedelta(seconds=2), + ), + ], + ("package1", None, datetime(2024, 10, 4, 2, 4) - timedelta(seconds=2)), + [("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", "1.0.0", datetime.now() - timedelta(seconds=2)), + [], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", None, datetime.now() - timedelta(seconds=1)), + [], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package2", None, None), + [], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", "1.0.2", None), + [], + ), + ], +) +def test_lookup_packages( + storage: DatabaseStorage, + scans: list[Scan], + spec: tuple[Optional[str], Optional[str], Optional[datetime]], + expected: list[tuple[str, str]], +): + session = storage.get_session() + with session, session.begin(): + session.add_all(deepcopy(scans)) + + name, version, since = spec + results = storage.lookup_packages(name=name, version=version, since=since) + + assert sorted((s.name, s.version) for s in results) == sorted(expected) diff --git a/tests/test_report.py b/tests/test_report.py index c163ec88..20bc9239 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from copy import deepcopy from typing import Optional from unittest.mock import MagicMock @@ -7,12 +6,9 @@ import pytest from fastapi import HTTPException from fastapi.encoders import jsonable_encoder -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker -from mainframe.endpoints.report import ( - _lookup_package, # pyright: ignore [reportPrivateUsage] -) +from mainframe.custom_exceptions import PackageAlreadyReported, PackageNotFound +from mainframe.endpoints.report import validate_package from mainframe.endpoints.report import ( _validate_inspector_url, # pyright: ignore [reportPrivateUsage] ) @@ -27,26 +23,86 @@ ObservationReport, ReportPackageBody, ) +from tests.conftest import MockDatabase -def test_report( - sm: sessionmaker[Session], - db_session: Session, - auth: AuthenticationData, -): +def test_validate_package(): + scan1 = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=None, + ) + + assert validate_package("package1", "1.0.0", [scan1]) == scan1 + + +def test_validate_package_not_found(): + scan1 = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=None, + ) + + with pytest.raises(PackageNotFound): + validate_package("package2", "1.0.0", [scan1]) + + +def test_validate_package_already_reported(): + scan1 = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=None, + ) + scan2 = Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=datetime.now(), + ) + + with pytest.raises(PackageAlreadyReported) as e: + validate_package("package1", "1.0.0", [scan1, scan2]) + + assert (e.value.name, e.value.reported_version) == ("package1", "1.0.1") + + +def test_report_package_not_on_pypi(): + mock_httpx_client = MagicMock(spec=httpx.Client) + mock_httpx_client.configure_mock(**{"get.return_value.status_code": 404}) + + with pytest.raises(HTTPException) as e: + _validate_pypi("c", "1.0.0", mock_httpx_client) + + assert e.value.status_code == 404 + + +def test_report_package_not_found(auth: AuthenticationData, mock_database: MockDatabase): body = ReportPackageBody( - name="c", + name="this-package-does-not-exist", version="1.0.0", inspector_url=None, additional_information="this package is bad", ) - report = ObservationReport( - kind=ObservationKind.Malware, - summary="this package is bad", - inspector_url="test inspector url", - extra=dict(yara_rules=["rule 1", "rule 2"]), - ) + with pytest.raises(HTTPException) as e: + report_package(body, mock_database, auth, MagicMock()) + + assert e.value.status_code == 404 + + +@pytest.mark.parametrize("version", ["1.0.0", "1.0.1"]) +def test_report_package_already_reported(auth: AuthenticationData, mock_database: MockDatabase, version: str): scan = Scan( name="c", version="1.0.0", @@ -59,47 +115,30 @@ def test_report( queued_by="remmy", pending_at=datetime.now() - timedelta(seconds=30), pending_by="remmy", - finished_at=datetime.now(), + finished_at=datetime.now() - timedelta(seconds=15), finished_by="remmy", - reported_at=None, - reported_by=None, + reported_at=datetime.now(), + reported_by="fishy", fail_reason=None, commit_hash="test commit hash", ) - with db_session.begin(): - db_session.add(scan) - - mock_httpx_client = MagicMock() - - report_package(body, sm(), auth, mock_httpx_client) - - mock_httpx_client.post.assert_called_once_with("/report/c", json=jsonable_encoder(report)) - - with sm() as sess, sess.begin(): - s = sess.scalar(select(Scan).where(Scan.name == "c").where(Scan.version == "1.0.0")) + mock_database.add(scan) - assert s is not None - assert s.reported_by == auth.subject - assert s.reported_at is not None - - -def test_report_package_not_on_pypi(): - mock_httpx_client = MagicMock(spec=httpx.Client) - mock_httpx_client.configure_mock(**{"get.return_value.status_code": 404}) + body = ReportPackageBody( + name="c", + version=version, + inspector_url=None, + additional_information="this package is bad", + ) with pytest.raises(HTTPException) as e: - _validate_pypi("c", "1.0.0", mock_httpx_client) - assert e.value.status_code == 404 + report_package(body, mock_database, auth, MagicMock()) - -def test_report_unscanned_package(db_session: Session): - with pytest.raises(HTTPException) as e: - _lookup_package("c", "1.0.0", db_session) - assert e.value.status_code == 404 + assert e.value.status_code == 409 -def test_report_invalid_version(db_session: Session): +def test_report(auth: AuthenticationData, mock_database: MockDatabase): scan = Scan( name="c", version="1.0.0", @@ -112,19 +151,38 @@ def test_report_invalid_version(db_session: Session): queued_by="remmy", pending_at=datetime.now() - timedelta(seconds=30), pending_by="remmy", - finished_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now(), finished_by="remmy", reported_at=None, - reported_by="remmy", + reported_by=None, fail_reason=None, commit_hash="test commit hash", ) - with db_session.begin(): - db_session.add(scan) - with pytest.raises(HTTPException) as e: - _lookup_package("c", "2.0.0", db_session) - assert e.value.status_code == 404 + mock_database.add(scan) + + body = ReportPackageBody( + name="c", + version="1.0.0", + inspector_url=None, + additional_information="this package is bad", + ) + + expected = ObservationReport( + kind=ObservationKind.Malware, + summary="this package is bad", + inspector_url="test inspector url", + extra=dict(yara_rules=["rule 1", "rule 2"]), + ) + + mock_httpx_client = MagicMock() + + report_package(body, mock_database, auth, mock_httpx_client) + + mock_httpx_client.post.assert_called_once_with("/report/c", json=jsonable_encoder(expected)) + + assert scan.reported_by is auth.subject + assert scan.reported_at is not None def test_report_missing_inspector_url(): @@ -145,9 +203,9 @@ def test_report_inspector_url(body_url: Optional[str], scan_url: Optional[str]): @pytest.mark.parametrize( - ("scans", "name", "version", "expected_status_code"), + ("scans", "name", "version", "expected_exception"), [ - ([], "a", "1.0.0", 404), + ([], "a", "1.0.0", PackageNotFound), ( [ Scan( @@ -191,7 +249,7 @@ def test_report_inspector_url(body_url: Optional[str], scan_url: Optional[str]): ], "c", "1.0.1", - 409, + PackageAlreadyReported, ), ( [ @@ -236,45 +294,12 @@ def test_report_inspector_url(body_url: Optional[str], scan_url: Optional[str]): ], "c", "2.0.0", - 409, + PackageAlreadyReported, ), ], ) def test_report_lookup_package_validation( - db_session: Session, scans: list[Scan], name: str, version: str, expected_status_code: int + scans: list[Scan], name: str, version: str, expected_exception: type[Exception] ): - with db_session.begin(): - db_session.add_all(deepcopy(scans)) - - with pytest.raises(HTTPException) as e: - _lookup_package(name, version, db_session) - assert e.value.status_code == expected_status_code - - -def test_report_lookup_package(db_session: Session): - scan = Scan( - name="c", - version="1.0.0", - status=Status.FINISHED, - score=0, - inspector_url=None, - rules=[], - download_urls=[], - queued_at=datetime.now() - timedelta(seconds=60), - queued_by="remmy", - pending_at=datetime.now() - timedelta(seconds=30), - pending_by="remmy", - finished_at=datetime.now() - timedelta(seconds=10), - finished_by="remmy", - reported_at=None, - reported_by=None, - fail_reason=None, - commit_hash="test commit hash", - ) - - with db_session.begin(): - db_session.add(scan) - - res = _lookup_package("c", "1.0.0", db_session) - - assert res == scan + with pytest.raises(expected_exception): + validate_package(name, version, scans)