diff --git a/.flake8 b/.flake8 index a92b3e1..afc2645 100644 --- a/.flake8 +++ b/.flake8 @@ -11,6 +11,6 @@ exclude = docs/source/conf.py max-line-length = 88 select = C,E,F,W,B,B950 -extend-ignore = E203,E501,E129,W503 +extend-ignore = E203,E501,E129,W503,E701 per-file-ignores = setup.py:F401 diff --git a/conftest.py b/conftest.py index 33713ed..b9bfaf8 100644 --- a/conftest.py +++ b/conftest.py @@ -3,8 +3,10 @@ import logging import typing as ty import tempfile -from logging.handlers import SMTPHandler + +# from logging.handlers import SMTPHandler import pytest +import click.testing from click.testing import CliRunner import xnat4tests # type: ignore[import-untyped] from datetime import datetime @@ -31,11 +33,12 @@ if os.getenv("_PYTEST_RAISE", "0") != "0": @pytest.hookimpl(tryfirst=True) - def pytest_exception_interact(call): - raise call.excinfo.value + def pytest_exception_interact(call: pytest.CallInfo[ty.Any]) -> None: + if call.excinfo is not None: + raise call.excinfo.value @pytest.hookimpl(tryfirst=True) - def pytest_internalerror(excinfo): + def pytest_internalerror(excinfo: pytest.ExceptionInfo[BaseException]) -> None: raise excinfo.value CATCH_CLI_EXCEPTIONS = False @@ -44,28 +47,28 @@ def pytest_internalerror(excinfo): @pytest.fixture -def catch_cli_exceptions(): +def catch_cli_exceptions() -> bool: return CATCH_CLI_EXCEPTIONS @pytest.fixture(scope="session") -def run_prefix(): +def run_prefix() -> str: "A datetime string used to avoid stale data left over from previous tests" return datetime.strftime(datetime.now(), "%Y%m%d%H%M%S") @pytest.fixture(scope="session") -def xnat_repository(): +def xnat_repository() -> None: xnat4tests.start_xnat() @pytest.fixture(scope="session") -def xnat_archive_dir(xnat_repository): - return xnat4tests.Config().xnat_root_dir / "archive" +def xnat_archive_dir(xnat_repository: None) -> Path: + return xnat4tests.Config().xnat_root_dir / "archive" # type: ignore[no-any-return] @pytest.fixture(scope="session") -def tmp_gen_dir(): +def tmp_gen_dir() -> Path: # tmp_gen_dir = Path("~").expanduser() / ".xnat-ingest-work3" # tmp_gen_dir.mkdir(exist_ok=True) # return tmp_gen_dir @@ -73,12 +76,12 @@ def tmp_gen_dir(): @pytest.fixture(scope="session") -def xnat_login(xnat_repository): +def xnat_login(xnat_repository: str) -> ty.Any: return xnat4tests.connect() @pytest.fixture(scope="session") -def xnat_project(xnat_login, run_prefix): +def xnat_project(xnat_login: ty.Any, run_prefix: str) -> ty.Any: project_id = f"INGESTUPLOAD{run_prefix}" with xnat4tests.connect() as xnat_login: xnat_login.put(f"/data/archive/projects/{project_id}") @@ -86,40 +89,46 @@ def xnat_project(xnat_login, run_prefix): @pytest.fixture(scope="session") -def xnat_server(xnat_config): - return xnat_config.xnat_uri +def xnat_server(xnat_config: xnat4tests.Config) -> str: + return xnat_config.xnat_uri # type: ignore[no-any-return] @pytest.fixture(scope="session") -def xnat_config(xnat_repository): +def xnat_config(xnat_repository: str) -> xnat4tests.Config: return xnat4tests.Config() @pytest.fixture -def cli_runner(catch_cli_exceptions): - def invoke(*args, catch_exceptions=catch_cli_exceptions, **kwargs): +def cli_runner(catch_cli_exceptions: bool) -> ty.Callable[..., ty.Any]: + def invoke( + *args: ty.Any, catch_exceptions: bool = catch_cli_exceptions, **kwargs: ty.Any + ) -> click.testing.Result: runner = CliRunner() - result = runner.invoke(*args, catch_exceptions=catch_exceptions, **kwargs) + result = runner.invoke(*args, catch_exceptions=catch_exceptions, **kwargs) # type: 
ignore[misc] return result return invoke -# Create a custom handler that captures email messages for testing -class TestSMTPHandler(SMTPHandler): - def __init__( - self, mailhost, fromaddr, toaddrs, subject, credentials=None, secure=None - ): - super().__init__(mailhost, fromaddr, toaddrs, subject, credentials, secure) - self.emails = [] # A list to store captured email messages +# # Create a custom handler that captures email messages for testing +# class TestSMTPHandler(SMTPHandler): +# def __init__( +# self, mailhost, fromaddr, toaddrs, subject, credentials=None, secure=None +# ): +# super().__init__(mailhost, fromaddr, toaddrs, subject, credentials, secure) +# self.emails = [] # A list to store captured email messages - def emit(self, record): - # Capture the email message and append it to the list - msg = self.format(record) - self.emails.append(msg) +# def emit(self, record): +# # Capture the email message and append it to the list +# msg = self.format(record) +# self.emails.append(msg) -def get_raw_data_files(out_dir: ty.Optional[Path] = None, **kwargs) -> ty.List[Path]: +def get_raw_data_files( + out_dir: ty.Optional[Path] = None, **kwargs: ty.Any +) -> ty.List[Path]: if out_dir is None: out_dir = Path(tempfile.mkdtemp()) - return get_listmode_data(out_dir, **kwargs) + get_countrate_data(out_dir, **kwargs) + return get_listmode_data(out_dir, skip_unknown=True, **kwargs) + get_countrate_data( # type: ignore[no-any-return] + out_dir, skip_unknown=True, **kwargs + ) diff --git a/pyproject.toml b/pyproject.toml index a23080d..899cab0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ "click >=8.1", + "discord", "fileformats-medimage-extras", "pydicom >=2.3.1", "tqdm >=4.64.1", @@ -84,4 +85,19 @@ doctests = true per-file-ignores = ["__init__.py:F401"] max-line-length = 88 select = "C,E,F,W,B,B950" -extend-ignore = ['E203', 'E501', 'E129', "W503"] +extend-ignore = ['E203', 'E501', 'E129', "W503", 'E701'] + + +[tool.mypy] +python_version = "3.10" +ignore_missing_imports = true +strict = true +explicit_package_bases = true +exclude = [ + "tests", + "scripts", + "docs", + "build", + "dist", + "xnat_ingest/_version.py", +] diff --git a/real-tests/usyd_stage.py b/real-tests/usyd_stage.py index 39e4cf6..6caf938 100644 --- a/real-tests/usyd_stage.py +++ b/real-tests/usyd_stage.py @@ -9,15 +9,15 @@ stage, [], env={ - "XNAT_INGEST_STAGE_DICOMS_PATH": "/vol/vmhost/kubernetes/////**/*.IMA", - "XNAT_INGEST_STAGE_DIR": "/vol/vmhost/usyd-data-export/STAGING", - "XNAT_INGEST_STAGE_PROJECT": "ProtocolName", - "XNAT_INGEST_STAGE_SUBJECT": "PatientID", - "XNAT_INGEST_STAGE_VISIT": "AccessionNumber", - "XNAT_INGEST_STAGE_ASSOCIATED": '"/vol/vmhost/usyd-data-export/RAW-DATA-EXPORT/{PatientName.family_name}_{PatientName.given_name}/.ptd","./[^\\.]+.[^\\.]+.[^\\.]+.(?P\\d+).[A-Z]+_(?P[^\\.]+)."', - "XNAT_INGEST_STAGE_DELETE": "0", - "XNAT_INGEST_STAGE_LOGFILE": ",INFO", - "XNAT_INGEST_STAGE_DEIDENTIFY": "1", + "XINGEST_DICOMS_PATH": "/vol/vmhost/kubernetes/////**/*.IMA", + "XINGEST_DIR": "/vol/vmhost/usyd-data-export/STAGING", + "XINGEST_PROJECT": "ProtocolName", + "XINGEST_SUBJECT": "PatientID", + "XINGEST_VISIT": "AccessionNumber", + "XINGEST_ASSOCIATED": '"/vol/vmhost/usyd-data-export/RAW-DATA-EXPORT/{PatientName.family_name}_{PatientName.given_name}/.ptd","./[^\\.]+.[^\\.]+.[^\\.]+.(?P\\d+).[A-Z]+_(?P[^\\.]+)."', + "XINGEST_DELETE": "0", + "XINGEST_LOGFILE": ",INFO", + "XINGEST_DEIDENTIFY": "1", }, catch_exceptions=False, ) diff 
--git a/real-tests/usyd_transfer.py b/real-tests/usyd_transfer.py index d143b51..baa7b5e 100644 --- a/real-tests/usyd_transfer.py +++ b/real-tests/usyd_transfer.py @@ -9,7 +9,7 @@ transfer, [], env={ - "XNAT_INGEST_STAGE_DIR": "/Users/tclose/Data/testing/staging-test/", + "XINGEST_DIR": "/Users/tclose/Data/testing/staging-test/", "XNAT_INGEST_TRANSFER_LOGFILE": "/Users/tclose/Desktop/test-log.log,INFO", "XNAT_INGEST_TRANSFER_DELETE": "0", }, diff --git a/real-tests/usyd_upload.py b/real-tests/usyd_upload.py index 2a498fa..4a039b2 100644 --- a/real-tests/usyd_upload.py +++ b/real-tests/usyd_upload.py @@ -8,17 +8,17 @@ upload, [], env={ - "XNAT_INGEST_UPLOAD_STAGED": "", - "XNAT_INGEST_UPLOAD_HOST": "https://xnat.sydney.edu.au", - "XNAT_INGEST_UPLOAD_USER": "", - "XNAT_INGEST_UPLOAD_PASS": "", - "XNAT_INGEST_UPLOAD_ALWAYSINCLUDE": "medimage/dicom-series", - "XNAT_INGEST_UPLOAD_STORE_CREDENTIALS": ",", - "XNAT_INGEST_UPLOAD_LOGFILE": ",INFO", - "XNAT_INGEST_UPLOAD_DELETE": "0", - "XNAT_INGEST_UPLOAD_TEMPDIR": "", - "XNAT_INGEST_UPLOAD_REQUIRE_MANIFEST": "1", - "XNAT_INGEST_UPLOAD_CLEANUP_OLDER_THAN": "30", + "XINGEST_STAGED": "", + "XINGEST_HOST": "https://xnat.sydney.edu.au", + "XINGEST_USER": "", + "XINGEST_PASS": "", + "XINGEST_ALWAYSINCLUDE": "medimage/dicom-series", + "XINGEST_STORE_CREDENTIALS": ",", + "XINGEST_LOGFILE": ",INFO", + "XINGEST_DELETE": "0", + "XINGEST_TEMPDIR": "", + "XINGEST_REQUIRE_MANIFEST": "1", + "XINGEST_CLEANUP_OLDER_THAN": "30", }, catch_exceptions=False, ) diff --git a/xnat_ingest/__init__.py b/xnat_ingest/__init__.py index 8dee4bf..26d23ba 100644 --- a/xnat_ingest/__init__.py +++ b/xnat_ingest/__init__.py @@ -1 +1,3 @@ from ._version import __version__ + +__all__ = ["__version__"] diff --git a/xnat_ingest/cli/base.py b/xnat_ingest/cli/base.py index f48ad6d..911b238 100644 --- a/xnat_ingest/cli/base.py +++ b/xnat_ingest/cli/base.py @@ -4,5 +4,5 @@ @click.group(help="Checks and uploads scans exported from scanner consoles to XNAT") @click.version_option(version=__version__) -def cli(): +def cli() -> None: pass diff --git a/xnat_ingest/cli/stage.py b/xnat_ingest/cli/stage.py index a6abeaa..79eba75 100644 --- a/xnat_ingest/cli/stage.py +++ b/xnat_ingest/cli/stage.py @@ -2,21 +2,27 @@ import typing as ty import traceback import click +import datetime +import time import tempfile from tqdm import tqdm +from fileformats.core import FileSet from xnat_ingest.cli.base import cli from xnat_ingest.session import ImagingSession from frametree.xnat import Xnat # type: ignore[import-untyped] from xnat_ingest.utils import ( AssociatedFiles, logger, - LogFile, - LogEmail, - MailServer, + LoggerConfig, XnatLogin, set_logger_handling, ) +PRE_STAGE_NAME_DEFAULT = "PRE-STAGE" +STAGED_NAME_DEFAULT = "STAGED" +INVALID_NAME_DEFAULT = "INVALID" +DEIDENTIFIED_NAME_DEFAULT = "DEIDENTIFIED" + @cli.command( help="""Stages DICOM and associated files found in the input directories into separate @@ -29,38 +35,36 @@ are uploaded to XNAT """, ) -@click.argument("files_path", type=str, envvar="XNAT_INGEST_STAGE_DICOMS_PATH") -@click.argument( - "staging_dir", type=click.Path(path_type=Path), envvar="XNAT_INGEST_STAGE_DIR" -) +@click.argument("files_path", type=str, envvar="XINGEST_DICOMS_PATH") +@click.argument("output_dir", type=click.Path(path_type=Path), envvar="XINGEST_DIR") @click.option( "--datatype", type=str, metavar="", multiple=True, default=["medimage/dicom-series"], - envvar="XNAT_INGEST_STAGE_DATATYPE", + envvar="XINGEST_DATATYPE", help="The datatype of the primary files to to 
upload", ) @click.option( "--project-field", type=str, default="StudyID", - envvar="XNAT_INGEST_STAGE_PROJECT", + envvar="XINGEST_PROJECT", help=("The keyword of the metadata field to extract the XNAT project ID from "), ) @click.option( "--subject-field", type=str, default="PatientID", - envvar="XNAT_INGEST_STAGE_SUBJECT", + envvar="XINGEST_SUBJECT", help=("The keyword of the metadata field to extract the XNAT subject ID from "), ) @click.option( "--visit-field", type=str, default="AccessionNumber", - envvar="XNAT_INGEST_STAGE_VISIT", + envvar="XINGEST_VISIT", help=( "The keyword of the metadata field to extract the XNAT imaging session ID from " ), @@ -69,7 +73,7 @@ "--session-field", type=str, default=None, - envvar="XNAT_INGEST_STAGE_SESSION", + envvar="XINGEST_SESSION", help=( "The keyword of the metadata field to extract the XNAT imaging session ID from " ), @@ -78,7 +82,7 @@ "--scan-id-field", type=str, default="SeriesNumber", - envvar="XNAT_INGEST_STAGE_SCAN_ID", + envvar="XINGEST_SCAN_ID", help=( "The keyword of the metadata field to extract the XNAT imaging scan ID from " ), @@ -87,7 +91,7 @@ "--scan-desc-field", type=str, default="SeriesDescription", - envvar="XNAT_INGEST_STAGE_SCAN_DESC", + envvar="XINGEST_SCAN_DESC", help=( "The keyword of the metadata field to extract the XNAT imaging scan description from " ), @@ -96,7 +100,7 @@ "--resource-field", type=str, default="ImageType[-1]", - envvar="XNAT_INGEST_STAGE_RESOURCE", + envvar="XINGEST_RESOURCE", help=( "The keyword of the metadata field to extract the XNAT imaging resource ID from " ), @@ -113,7 +117,7 @@ nargs=3, default=None, multiple=True, - envvar="XNAT_INGEST_STAGE_ASSOCIATED", + envvar="XINGEST_ASSOCIATED", metavar=" ", help=( 'The "glob" arg is a glob pattern by which to detect associated files to be ' @@ -135,67 +139,33 @@ @click.option( "--delete/--dont-delete", default=False, - envvar="XNAT_INGEST_STAGE_DELETE", + envvar="XINGEST_DELETE", help="Whether to delete the session directories after they have been uploaded or not", ) @click.option( - "--log-level", - default="info", - type=str, - envvar="XNAT_INGEST_STAGE_LOGLEVEL", - help=("The level of the logging printed to stdout"), -) -@click.option( - "--log-file", - "log_files", - default=None, - type=LogFile.cli_type, + "--logger", + "loggers", multiple=True, - nargs=2, - metavar=" ", - envvar="XNAT_INGEST_STAGE_LOGFILE", - help=( - 'Location to write the output logs to, defaults to "upload-logs" in the ' - "export directory" - ), -) -@click.option( - "--log-email", - "log_emails", - type=LogEmail.cli_type, + type=LoggerConfig.cli_type, + envvar="XINGEST_LOGGERS", nargs=3, - metavar="
", - multiple=True, - envvar="XNAT_INGEST_STAGE_LOGEMAIL", - help=( - "Email(s) to send logs to. When provided in an environment variable, " - "mail and log level are delimited by ',' and separate destinations by ';'" - ), + default=(), + metavar=" ", + help=("Setup handles to capture logs that are generated"), ) @click.option( - "--add-logger", + "--additional-logger", + "additional_loggers", type=str, multiple=True, default=(), - envvar="XNAT_INGEST_UPLOAD_LOGGERS", + envvar="XINGEST_ADDITIONAL_LOGGERS", help=( "The loggers to use for logging. By default just the 'xnat-ingest' logger is used. " "But additional loggers can be included (e.g. 'xnat') can be " "specified here" ), ) -@click.option( - "--mail-server", - type=MailServer.cli_type, - nargs=4, - metavar=" ", - default=None, - envvar="XNAT_INGEST_STAGE_MAILSERVER", - help=( - "the mail server to send logger emails to. When provided in an environment variable, " - "args are delimited by ';'" - ), -) @click.option( "--raise-errors/--dont-raise-errors", default=False, @@ -206,7 +176,7 @@ "--deidentify/--dont-deidentify", default=False, type=bool, - envvar="XNAT_INGEST_STAGE_DEIDENTIFY", + envvar="XINGEST_DEIDENTIFY", help="whether to deidentify the file names and DICOM metadata before staging", ) @click.option( @@ -222,12 +192,64 @@ "--spaces-to-underscores/--no-spaces-to-underscores", default=False, help="Whether to replace spaces with underscores in the filenames of associated files", - envvar="XNAT_INGEST_STAGE_SPACES_TO_UNDERSCORES", + envvar="XINGEST_SPACES_TO_UNDERSCORES", type=bool, ) +@click.option( + "--work-dir", + type=click.Path(path_type=Path), + default=None, + envvar="XINGEST_WORK_DIR", + help=( + "The working directory to use for temporary files. Should be on the same " + "physical disk as the staging directory for optimal performance" + ), +) +@click.option( + "--copy-mode", + type=FileSet.CopyMode, + default=FileSet.CopyMode.hardlink_or_copy, + envvar="XINGEST_COPY_MODE", + help="The method to use for copying files", +) +@click.option( + "--loop", + type=int, + default=None, + envvar="XINGEST_LOOP", + help="Run the staging process continuously every LOOP seconds", +) +@click.option( + "--pre-stage-dir-name", + type=str, + default=PRE_STAGE_NAME_DEFAULT, + envvar="XINGEST_PRE_STAGE_DIR_NAME", + help="The name of the directory to use for pre-staging the files", +) +@click.option( + "--staged-dir-name", + type=str, + default=STAGED_NAME_DEFAULT, + envvar="XINGEST_STAGED_DIR_NAME", + help="The name of the directory to use for staging the files", +) +@click.option( + "--invalid-dir-name", + type=str, + default=INVALID_NAME_DEFAULT, + envvar="XINGEST_INVALID_DIR_NAME", + help="The name of the directory to use for invalid files", +) +@click.option( + "--deidentified-dir-name", + type=str, + default=DEIDENTIFIED_NAME_DEFAULT, + envvar="XINGEST_DEIDENTIFIED_DIR_NAME", + help="The name of the directory to use for deidentified files", +) def stage( files_path: str, - staging_dir: Path, + output_dir: Path, datatype: str, associated_files: ty.List[AssociatedFiles], project_field: str, @@ -239,22 +261,23 @@ def stage( resource_field: str, project_id: str | None, delete: bool, - log_level: str, - log_files: ty.List[LogFile], - log_emails: ty.List[LogEmail], - add_logger: ty.List[str], - mail_server: MailServer, + loggers: ty.List[LoggerConfig], + additional_loggers: ty.List[str], raise_errors: bool, deidentify: bool, xnat_login: XnatLogin, spaces_to_underscores: bool, -): + copy_mode: FileSet.CopyMode, + pre_stage_dir_name: str, + 
staged_dir_name: str, + invalid_dir_name: str, + deidentified_dir_name: str, + loop: int | None, + work_dir: Path | None = None, +) -> None: set_logger_handling( - log_level=log_level, - log_emails=log_emails, - log_files=log_files, - mail_server=mail_server, - add_logger=add_logger, + logger_configs=loggers, + additional_loggers=additional_loggers, ) if xnat_login: @@ -281,48 +304,89 @@ def stage( logger.info(msg) - sessions = ImagingSession.from_paths( - files_path=files_path, - project_field=project_field, - subject_field=subject_field, - visit_field=visit_field, - session_field=session_field, - scan_id_field=scan_id_field, - scan_desc_field=scan_desc_field, - resource_field=resource_field, - project_id=project_id, - ) + # Create sub-directories of the output directory for the different phases of the + # staging process + prestage_dir = output_dir / pre_stage_dir_name + staged_dir = output_dir / staged_dir_name + invalid_dir = output_dir / invalid_dir_name + prestage_dir.mkdir(parents=True, exist_ok=True) + staged_dir.mkdir(parents=True, exist_ok=True) + invalid_dir.mkdir(parents=True, exist_ok=True) + if deidentify: + deidentified_dir = output_dir / deidentified_dir_name + deidentified_dir.mkdir(parents=True, exist_ok=True) - logger.info("Staging sessions to '%s'", str(staging_dir)) + def do_stage() -> None: + sessions = ImagingSession.from_paths( + files_path=files_path, + project_field=project_field, + subject_field=subject_field, + visit_field=visit_field, + session_field=session_field, + scan_id_field=scan_id_field, + scan_desc_field=scan_desc_field, + resource_field=resource_field, + project_id=project_id, + ) + + logger.info("Staging sessions to '%s'", str(output_dir)) - for session in tqdm(sessions, f"Staging DICOM sessions found in '{files_path}'"): - try: - session_staging_dir = staging_dir.joinpath(*session.staging_relpath) - if session_staging_dir.exists(): - logger.info( - "Skipping %s session as staging directory %s already exists", - session.name, - str(session_staging_dir), + for session in tqdm(sessions, f"Staging resources found in '{files_path}'"): + try: + if associated_files: + session.associate_files( + associated_files, + spaces_to_underscores=spaces_to_underscores, + ) + if deidentify: + deidentified_session = session.deidentify( + deidentified_dir, + copy_mode=copy_mode, + ) + if delete: + session.unlink() + session = deidentified_session + # We save the session into a temporary "pre-stage" directory first before + # moving them into the final "staged" directory. This is to prevent the + # files being transferred/deleted until the saved session is in a final state. 
+ _, saved_dir = session.save( + prestage_dir, + available_projects=project_list, + copy_mode=copy_mode, ) - continue - # Identify theDeidentify files if necessary and save them to the staging directory - session.stage( - staging_dir, - associated_file_groups=associated_files, - remove_original=delete, - deidentify=deidentify, - project_list=project_list, - spaces_to_underscores=spaces_to_underscores, + if "INVALID" in saved_dir.name: + saved_dir.rename(invalid_dir / saved_dir.relative_to(prestage_dir)) + else: + saved_dir.rename(staged_dir / saved_dir.relative_to(prestage_dir)) + if delete: + session.unlink() + except Exception as e: + if not raise_errors: + logger.error( + f"Skipping '{session.name}' session due to error in staging: \"{e}\"" + f"\n{traceback.format_exc()}\n\n" + ) + continue + else: + raise + + if loop: + while True: + start_time = datetime.datetime.now() + do_stage() + end_time = datetime.datetime.now() + elapsed_seconds = (end_time - start_time).total_seconds() + sleep_time = loop - elapsed_seconds + logger.info( + "Stage took %s seconds, waiting another %s seconds before running " + "again (loop every %s seconds)", + elapsed_seconds, + sleep_time, + loop, ) - except Exception as e: - if not raise_errors: - logger.error( - f"Skipping '{session.name}' session due to error in staging: \"{e}\"" - f"\n{traceback.format_exc()}\n\n" - ) - continue - else: - raise + time.sleep(loop) + else: + do_stage() if __name__ == "__main__": diff --git a/xnat_ingest/cli/upload.py b/xnat_ingest/cli/upload.py index 7de9b4a..7526b2f 100644 --- a/xnat_ingest/cli/upload.py +++ b/xnat_ingest/cli/upload.py @@ -1,35 +1,36 @@ from pathlib import Path -import shutil -import os -import datetime import traceback import typing as ty -from collections import defaultdict import tempfile -from operator import itemgetter +import time +import datetime import subprocess as sp import click from tqdm import tqdm -from natsort import natsorted -import xnat # type: ignore[import-untyped] -import boto3 -import paramiko +import xnat from fileformats.generic import File -from frametree.core.frameset import FrameSet # type: ignore[import-untyped] -from frametree.xnat import Xnat # type: ignore[import-untyped] -from xnat.exceptions import XNATResponseError # type: ignore[import-untyped] +from frametree.core.frameset import FrameSet +from frametree.xnat import Xnat +from xnat.exceptions import XNATResponseError from xnat_ingest.cli.base import cli from xnat_ingest.session import ImagingSession +from xnat_ingest.resource import ImagingResource from xnat_ingest.utils import ( logger, - LogFile, - LogEmail, - MailServer, + LoggerConfig, set_logger_handling, - get_checksums, - calculate_checksums, StoreCredentials, ) +from xnat_ingest.upload_helpers import ( + get_xnat_session, + get_xnat_resource, + get_xnat_checksums, + calculate_checksums, + iterate_s3_sessions, + remove_old_files_on_s3, + remove_old_files_on_ssh, + dir_older_than, +) @cli.command( @@ -47,74 +48,41 @@ PASSWORD is the password for the XNAT user, alternatively "XNAT_INGEST_PASS" env. 
var """, ) -@click.argument("staged", type=str, envvar="XNAT_INGEST_UPLOAD_STAGED") -@click.argument("server", type=str, envvar="XNAT_INGEST_UPLOAD_HOST") -@click.argument("user", type=str, envvar="XNAT_INGEST_UPLOAD_USER") -@click.option("--password", default=None, type=str, envvar="XNAT_INGEST_UPLOAD_PASS") +@click.argument("staged", type=str, envvar="XINGEST_STAGED") +@click.argument("server", type=str, envvar="XINGEST_HOST") +@click.argument("user", type=str, envvar="XINGEST_USER") +@click.option("--password", default=None, type=str, envvar="XINGEST_PASS") @click.option( - "--log-level", - default="info", - type=str, - envvar="XNAT_INGEST_UPLOAD_LOGLEVEL", - help=("The level of the logging printed to stdout"), -) -@click.option( - "--log-file", - "log_files", - default=None, - type=LogFile.cli_type, - nargs=2, - metavar=" ", + "--logger", + "loggers", multiple=True, - envvar="XNAT_INGEST_UPLOAD_LOGFILE", - help=( - 'Location to write the output logs to, defaults to "upload-logs" in the ' - "export directory" - ), -) -@click.option( - "--log-email", - "log_emails", - type=LogEmail.cli_type, + type=LoggerConfig.cli_type, + envvar="XINGEST_LOGGERS", nargs=3, - metavar="
", - multiple=True, - envvar="XNAT_INGEST_UPLOAD_LOGEMAIL", - help=( - "Email(s) to send logs to. When provided in an environment variable, " - "mail and log level are delimited by ',' and separate destinations by ';'" - ), + default=(), + metavar=" ", + help=("Setup handles to capture logs that are generated"), ) @click.option( - "--add-logger", + "--additional-logger", + "additional_loggers", type=str, multiple=True, default=(), - envvar="XNAT_INGEST_UPLOAD_LOGGERS", + envvar="XINGEST_ADDITIONALLOGGERS", help=( "The loggers to use for logging. By default just the 'xnat-ingest' logger is used. " "But additional loggers can be included (e.g. 'xnat') can be " "specified here" ), ) -@click.option( - "--mail-server", - type=MailServer.cli_type, - metavar=" ", - default=None, - envvar="XNAT_INGEST_UPLOAD_MAILSERVER", - help=( - "the mail server to send logger emails to. When provided in an environment variable, " - "args are delimited by ';'" - ), -) @click.option( "--always-include", "-i", default=(), type=str, multiple=True, - envvar="XNAT_INGEST_UPLOAD_ALWAYSINCLUDE", + envvar="XINGEST_ALWAYSINCLUDE", help=( "Scan types to always include in the upload, regardless of whether they are" "specified in a column or not. Specified using the scan types IANA mime-type or " @@ -133,7 +101,7 @@ "--store-credentials", type=StoreCredentials.cli_type, metavar=" ", - envvar="XNAT_INGEST_UPLOAD_STORE_CREDENTIALS", + envvar="XINGEST_STORE_CREDENTIALS", default=None, nargs=2, help="Credentials to use to access of data stored in remote stores (e.g. AWS S3)", @@ -142,24 +110,21 @@ "--temp-dir", type=Path, default=None, - envvar="XNAT_INGEST_UPLOAD_TEMPDIR", + envvar="XINGEST_TEMPDIR", help="The directory to use for temporary downloads (i.e. from s3)", ) @click.option( - "--use-manifest/--dont-use-manifest", + "--require-manifest/--dont-require-manifest", default=None, - envvar="XNAT_INGEST_UPLOAD_REQUIRE_MANIFEST", - help=( - "Whether to use the manifest file in the staged sessions to load the " - "directory structure. By default it is used if present and ignore if not there" - ), + envvar="XINGEST_REQUIRE_MANIFEST", + help=("Whether to require manifest files in the staged resources or not"), type=bool, ) @click.option( "--clean-up-older-than", type=int, metavar="", - envvar="XNAT_INGEST_UPLOAD_CLEANUP_OLDER_THAN", + envvar="XINGEST_CLEANUP_OLDER_THAN", default=0, help="The number of days to keep files in the remote store for", ) @@ -167,14 +132,14 @@ "--verify-ssl/--dont-verify-ssl", type=bool, default=True, - envvar="XNAT_INGEST_UPLOAD_VERIFY_SSL", + envvar="XINGEST_VERIFY_SSL", help="Whether to verify the SSL certificate of the XNAT server", ) @click.option( "--use-curl-jsession/--dont-use-curl-jsession", type=bool, default=False, - envvar="XNAT_INGEST_UPLOAD_USE_CURL_JSESSION", + envvar="XINGEST_USE_CURL_JSESSION", help=( "Whether to use CURL to create a JSESSION token to authenticate with XNAT. This is " "used to work around a strange authentication issue when running within a Kubernetes " @@ -185,41 +150,57 @@ "--method", type=click.Choice(["per_file", "tar_memory", "tgz_memory", "tar_file", "tgz_file"]), default="tgz_file", - envvar="XNAT_INGEST_UPLOAD_METHOD", + envvar="XINGEST_METHOD", help=( "The method to use to upload the files to XNAT. 
Passed through to XNATPy and controls " "whether directories are tarred and/or gzipped before they are uploaded, by default " "'tgz_file' is used" ), ) +@click.option( + "--wait-period", + type=int, + default=0, + envvar="XINGEST_WAIT_PERIOD", + help=( + "The number of seconds to wait since the last file modification in sessions " + "in the S3 bucket or source file-system directory before uploading them to " + "avoid uploading partial sessions" + ), +) +@click.option( + "--loop", + type=int, + default=None, + envvar="XINGEST_LOOP", + help="Run the staging process continuously every LOOP seconds", +) def upload( staged: str, server: str, user: str, password: str, - log_level: str, - log_files: ty.List[LogFile], - log_emails: ty.List[LogEmail], - mail_server: MailServer, + loggers: ty.List[LoggerConfig], + additional_loggers: ty.List[str], always_include: ty.Sequence[str], - add_logger: ty.List[str], raise_errors: bool, - store_credentials: ty.Tuple[str, str], + store_credentials: StoreCredentials, temp_dir: ty.Optional[Path], - use_manifest: bool, + require_manifest: bool, clean_up_older_than: int, verify_ssl: bool, use_curl_jsession: bool, method: str, -): + wait_period: int, + loop: int | None, +) -> None: set_logger_handling( - log_level=log_level, - log_emails=log_emails, - log_files=log_files, - mail_server=mail_server, - add_logger=add_logger, + logger_configs=loggers, + additional_loggers=additional_loggers, ) + + # Set the directory to create temporary files/directories in away from system default if temp_dir: tempfile.tempdir = str(temp_dir) @@ -231,376 +212,227 @@ def upload( verify_ssl=verify_ssl, ) - if use_curl_jsession: - jsession = sp.check_output( - [ - "curl", - "-X", - "PUT", - "-d", - f"username={user}&password={password}", - f"{server}/data/services/auth", - ] - ).decode("utf-8") - xnat_repo.connection.depth = 1 - xnat_repo.connection.session = xnat.connect( - server, user=user, jsession=jsession - ) - - with xnat_repo.connection: - - def xnat_session_exists(project_id, subject_id, visit_id): - try: - xnat_repo.connection.projects[project_id].subjects[ - subject_id - ].experiments[ - ImagingSession.make_session_id(project_id, subject_id, visit_id) + def do_upload() -> None: + if use_curl_jsession: + jsession = sp.check_output( + [ + "curl", + "-X", + "PUT", + "-d", + f"username={user}&password={password}", + f"{server}/data/services/auth", ] - except KeyError: - return False - else: - logger.info( - "Skipping session '%s-%s-%s' as it already exists on XNAT", - project_id, - subject_id, - visit_id, - ) - return True - - project_ids = set() - - if staged.startswith("s3://"): - # List sessions stored in s3 bucket - s3 = boto3.resource( - "s3", - aws_access_key_id=store_credentials.access_key, - aws_secret_access_key=store_credentials.access_secret, + ).decode("utf-8") + xnat_repo.connection.depth = 1 + xnat_repo.connection.session = xnat.connect( + server, user=user, jsession=jsession ) - bucket_name, prefix = staged[5:].split("/", 1) - bucket = s3.Bucket(bucket_name) - if not prefix.endswith("/"): - prefix += "/" - all_objects = bucket.objects.filter(Prefix=prefix) - session_objs = defaultdict(list) - for obj in all_objects: - if obj.key.endswith("/"): - continue # skip directories - path_parts = obj.key[len(prefix) :].split("/") - session_ids = tuple(path_parts[:3]) - project_ids.add(session_ids[0]) - session_objs[session_ids].append((path_parts[3:], obj)) - for ids, objs in list(session_objs.items()): - if xnat_session_exists(*ids): - logger.info( - "Skipping session 
'%s' as it already exists on XNAT", ids - ) - del session_objs[ids] + with xnat_repo.connection: - num_sessions = len(session_objs) - - if temp_dir: - tmp_download_dir = temp_dir / "xnat-ingest-download" - tmp_download_dir.mkdir(parents=True, exist_ok=True) + num_sessions: int + sessions: ty.Iterable[Path] + if staged.startswith("s3://"): + sessions = iterate_s3_sessions( + staged, store_credentials, temp_dir, wait_period=wait_period + ) + # bit of a hack: number of sessions is the first item in the iterator + num_sessions = next(sessions) # type: ignore[assignment] else: - tmp_download_dir = Path(tempfile.mkdtemp()) - - def iter_staged_sessions(): - for ids, objs in session_objs.items(): - # Just in case the manifest file is not included in the list of objects - # we recreate the project/subject/sesssion directory structure - session_tmp_dir = tmp_download_dir.joinpath(*ids) - session_tmp_dir.mkdir(parents=True, exist_ok=True) - for relpath, obj in tqdm( - objs, - desc=f"Downloading scans in {':'.join(ids)} session from S3 bucket", - ): - obj_path = session_tmp_dir.joinpath(*relpath) - obj_path.parent.mkdir(parents=True, exist_ok=True) - logger.debug("Downloading %s to %s", obj, obj_path) - with open(obj_path, "wb") as f: - bucket.download_fileobj(obj.key, f) - yield session_tmp_dir - shutil.rmtree( - session_tmp_dir - ) # Delete the tmp session after the upload - - logger.info("Found %d sessions in S3 bucket '%s'", num_sessions, staged) - sessions = iter_staged_sessions() - logger.debug("Created sessions iterator") - else: - sessions = [] - for project_dir in Path(staged).iterdir(): - for subject_dir in project_dir.iterdir(): - for session_dir in subject_dir.iterdir(): - if not xnat_session_exists( - project_dir.name, subject_dir.name, session_dir.name - ): - sessions.append(session_dir) - project_ids.add(project_dir.name) - num_sessions = len(sessions) - logger.info( - "Found %d sessions in staging directory '%s'", num_sessions, staged - ) - - # Check for dataset definitions on XNAT if an always_include option is not - # provided - if not always_include: - missing_datasets = set() - for project_id in project_ids: - try: - dataset = FrameSet.load(project_id, xnat_repo) - except Exception: - missing_datasets.add(project_id) - else: - logger.debug( - "Found dataset definition for '%s' project", project_id - ) - if missing_datasets: - raise ValueError( - "Either an '--always-include' option must be provided or dataset " - "definitions must be present on XNAT for the following projects " - f"({missing_datasets}) in order to upload the sessions" + sessions = [] + for session_dir in Path(staged).iterdir(): + if dir_older_than(session_dir, wait_period): + sessions.append(session_dir) + else: + logger.info( + "Skipping '%s' session as it has been modified recently", + session_dir, + ) + num_sessions = len(sessions) + logger.info( + "Found %d sessions in staging directory to stage'%s'", + num_sessions, + staged, ) - for session_staging_dir in tqdm( - sessions, - total=num_sessions, - desc=f"Processing staged sessions found in '{staged}'", - ): - session = ImagingSession.load( - session_staging_dir, use_manifest=use_manifest - ) - try: - if "MR" in session.modalities: - SessionClass = xnat_repo.connection.classes.MrSessionData - default_scan_modality = "MR" - elif "PT" in session.modalities: - SessionClass = xnat_repo.connection.classes.PetSessionData - default_scan_modality = "PT" - elif "CT" in session.modalities: - SessionClass = xnat_repo.connection.classes.CtSessionData - 
default_scan_modality = "CT" - else: - raise RuntimeError( - f"Found the following unsupported modalities {session.modalities}, " - "in the session. Must contain one of 'MR', 'PT' or 'CT'" - ) + framesets: dict[str, FrameSet] = {} - # Create corresponding session on XNAT - xproject = xnat_repo.connection.projects[session.project_id] + for session_staging_dir in tqdm( + sessions, + total=num_sessions, + desc=f"Processing staged sessions found in '{staged}'", + ): - # Access Arcana dataset associated with project - try: - dataset = FrameSet.load(session.project_id, xnat_repo) - except Exception as e: - logger.warning( - "Did not load dataset definition (%s) from %s project " - "on %s. Only the scan types specified in --always-include", - e, - session.project_id, - server, - ) - dataset = None - - xsubject = xnat_repo.connection.classes.SubjectData( - label=session.subject_id, parent=xproject + session = ImagingSession.load( + session_staging_dir, + require_manifest=require_manifest, ) try: - xsession = xproject.experiments[session.session_id] - except KeyError: - if "MR" in session.modalities: - SessionClass = xnat_repo.connection.classes.MrSessionData - elif "PT" in session.modalities: - SessionClass = xnat_repo.connection.classes.PetSessionData - elif "CT" in session.modalities: - SessionClass = xnat_repo.connection.classes.CtSessionData - else: - raise RuntimeError( - "Found the following unsupported modalities in " - f"{session.name}: {session.modalities}" + # Create corresponding session on XNAT + xproject = xnat_repo.connection.projects[session.project_id] + + # Access Arcana frameset associated with project + try: + frameset = framesets[session.project_id] + except KeyError: + try: + frameset = FrameSet.load(session.project_id, xnat_repo) + except Exception as e: + if not always_include: + logger.error( + "Did not load frameset definition (%s) from %s project " + "on %s. 
Either '--always-include' flag must be used or " + "the frameset must be defined on XNAT using the `frametree` " + "command line tool (see https://arcanaframework.github.io/frametree/).", + e, + session.project_id, + xnat_repo.server, + ) + continue + else: + frameset = None + framesets[session.project_id] = frameset + + xsession = get_xnat_session(session, xproject) + + # Anonymise DICOMs and save to directory prior to upload + if always_include: + logger.info( + f"Including {always_include} scans/files in upload from '{session.name}' to " + f"{session.path} regardless of whether they are explicitly specified" ) - xsession = SessionClass(label=session.session_id, parent=xsubject) - session_path = ( - f"{session.project_id}:{session.subject_id}:{session.visit_id}" - ) - - # Anonymise DICOMs and save to directory prior to upload - if always_include: - logger.info( - f"Including {always_include} scans/files in upload from '{session.name}' to " - f"{session_path} regardless of whether they are explicitly specified" - ) - for scan_id, scan_type, resource_name, scan in tqdm( - natsorted( - session.select_resources( - dataset, - always_include=always_include, + for resource in tqdm( + sorted( + session.select_resources( + frameset, always_include=always_include + ) ), - key=itemgetter(0), - ), - f"Uploading scans found in {session.name}", - ): - if scan.metadata: - image_type = scan.metadata.get("ImageType") - if image_type and image_type[:2] == ["DERIVED", "SECONDARY"]: - modality = "SC" - resource_name = "secondary" - else: - modality = scan.metadata.get( - "Modality", default_scan_modality + f"Uploading resources found in {session.name}", + ): + xresource = get_xnat_resource(resource, xsession) + if xresource is None: + logger.info( + "Skipping '%s' resource as it is already uploaded", + resource.path, ) - else: - modality = default_scan_modality - if modality == "SC": - ScanClass = xnat_repo.connection.classes.ScScanData - elif modality == "MR": - ScanClass = xnat_repo.connection.classes.MrScanData - elif modality == "PT": - ScanClass = xnat_repo.connection.classes.PetScanData - elif modality == "CT": - ScanClass = xnat_repo.connection.classes.CtScanData - else: - if SessionClass is xnat_repo.connection.classes.PetSessionData: - ScanClass = xnat_repo.connection.classes.PetScanData - elif SessionClass is xnat_repo.connection.classes.CtSessionData: - ScanClass = xnat_repo.connection.classes.CtScanData + continue # skipping as resource already exists + if isinstance(resource.fileset, File): + for fspath in resource.fileset.fspaths: + xresource.upload(str(fspath), fspath.name) else: - ScanClass = xnat_repo.connection.classes.MrScanData - logger.info( - "Can't determine modality of %s-%s scan, defaulting to the " - "default for %s sessions, %s", - scan_id, - scan_type, - SessionClass, - ScanClass, + # Temporarily move the manifest file out of the way so it + # doesn't get uploaded + manifest_file = ( + resource.fileset.parent / ImagingResource.MANIFEST_FNAME + ) + moved_manifest_file = ( + resource.fileset.parent.parent + / ImagingResource.MANIFEST_FNAME + ) + if manifest_file.exists(): + manifest_file.rename(moved_manifest_file) + # Upload the contents of the resource to XNAT + xresource.upload_dir(resource.fileset.parent, method=method) + # Move the manifest file back again + if moved_manifest_file.exists(): + moved_manifest_file.rename(manifest_file) + logger.debug("retrieving checksums for %s", xresource) + remote_checksums = get_xnat_checksums(xresource) + logger.debug("calculating checksums 
for %s", xresource) + calc_checksums = calculate_checksums(resource.fileset) + if remote_checksums != calc_checksums: + extra_keys = set(remote_checksums) - set(calc_checksums) + missing_keys = set(calc_checksums) - set(remote_checksums) + mismatching = [ + k + for k, v in calc_checksums.items() + if v != remote_checksums[k] + ] + raise RuntimeError( + "Checksums do not match after upload of " + f"'{resource.path}' resource.\n" + f"Extra keys were {extra_keys}\n" + f"Missing keys were {missing_keys}\n" + f"Mismatching files were {mismatching}" + ) + logger.info(f"Uploaded '{resource.path}' in '{session.name}'") + logger.info(f"Successfully uploaded all files in '{session.name}'") + # Extract DICOM metadata + logger.info("Extracting metadata from DICOMs on XNAT..") + try: + xnat_repo.connection.put( + f"/data/experiments/{xsession.id}?pullDataFromHeaders=true" ) - logger.debug("Creating scan %s in %s", scan_id, session_path) - xscan = ScanClass(id=scan_id, type=scan_type, parent=xsession) - logger.debug( - "Creating resource %s in %s in %s", - resource_name, - scan_id, - session_path, - ) - xresource = xscan.create_resource(resource_name) - if isinstance(scan, File): - for fspath in scan.fspaths: - xresource.upload(str(fspath), fspath.name) - else: - xresource.upload_dir(scan.parent, method=method) - logger.debug("retrieving checksums for %s", xresource) - remote_checksums = get_checksums(xresource) - logger.debug("calculating checksums for %s", xresource) - calc_checksums = calculate_checksums(scan) - if remote_checksums != calc_checksums: - mismatching = [ - k - for k, v in remote_checksums.items() - if v != calc_checksums[k] - ] - raise RuntimeError( - "Checksums do not match after upload of " - f"'{session.name}:{scan_id}:{resource_name}' resource. 
" - f"Mismatching files were {mismatching}" + except XNATResponseError as e: + logger.warning( + f"Failed to extract metadata from DICOMs in '{session.name}': {e}" ) - logger.info(f"Uploaded '{scan_id}' in '{session.name}'") - logger.info(f"Successfully uploaded all files in '{session.name}'") - # Extract DICOM metadata - logger.info("Extracting metadata from DICOMs on XNAT..") - try: - xnat_repo.connection.put( - f"/data/experiments/{xsession.id}?pullDataFromHeaders=true" - ) - except XNATResponseError as e: - logger.warning( - f"Failed to extract metadata from DICOMs in '{session.name}': {e}" - ) - try: - xnat_repo.connection.put( - f"/data/experiments/{xsession.id}?fixScanTypes=true" - ) - except XNATResponseError as e: - logger.warning(f"Failed to fix scan types in '{session.name}': {e}") - try: - xnat_repo.connection.put( - f"/data/experiments/{xsession.id}?triggerPipelines=true" - ) - except XNATResponseError as e: - logger.warning( - f"Failed to trigger pipelines in '{session.name}': {e}" - ) - logger.info(f"Succesfully uploaded all files in '{session.name}'") - except Exception as e: - if not raise_errors: - logger.error( - f"Skipping '{session.name}' session due to error in staging: \"{e}\"" - f"\n{traceback.format_exc()}\n\n" - ) - continue - else: - raise - - if use_curl_jsession: - xnat_repo.connection.exit() - - if clean_up_older_than: - logger.info( - "Cleaning up files in %s older than %d days", - staged, - clean_up_older_than, - ) - if staged.startswith("s3://"): - remove_old_files_on_s3(remote_store=staged, threshold=clean_up_older_than) - elif "@" in staged: - remove_old_files_on_ssh(remote_store=staged, threshold=clean_up_older_than) - else: - assert False - - -def remove_old_files_on_s3(remote_store: str, threshold: int): - # Parse S3 bucket and prefix from remote store - bucket_name, prefix = remote_store[5:].split("/", 1) - - # Create S3 client - s3_client = boto3.client("s3") - - # List objects in the bucket with the specified prefix - response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix) - - now = datetime.datetime.now() - - # Iterate over objects and delete files older than the threshold - for obj in response.get("Contents", []): - last_modified = obj["LastModified"] - age = (now - last_modified).days - if age > threshold: - s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) - - -def remove_old_files_on_ssh(remote_store: str, threshold: int): - # Parse SSH server and directory from remote store - server, directory = remote_store.split("@", 1) - - # Create SSH client - ssh_client = paramiko.SSHClient() - ssh_client.load_system_host_keys() - ssh_client.connect(server) - - # Execute find command to list files in the directory - stdin, stdout, stderr = ssh_client.exec_command(f"find {directory} -type f") - - now = datetime.datetime.now() + try: + xnat_repo.connection.put( + f"/data/experiments/{xsession.id}?fixScanTypes=true" + ) + except XNATResponseError as e: + logger.warning( + f"Failed to fix scan types in '{session.name}': {e}" + ) + try: + xnat_repo.connection.put( + f"/data/experiments/{xsession.id}?triggerPipelines=true" + ) + except XNATResponseError as e: + logger.warning( + f"Failed to trigger pipelines in '{session.name}': {e}" + ) + logger.info(f"Succesfully uploaded all files in '{session.name}'") + except Exception as e: + if not raise_errors: + logger.error( + f"Skipping '{session.name}' session due to error in staging: \"{e}\"" + f"\n{traceback.format_exc()}\n\n" + ) + continue + else: + raise - # Iterate over files and 
delete files older than the threshold - for file_path in stdout.read().decode().splitlines(): - last_modified = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)) - age = (now - last_modified).days - if age > threshold: - ssh_client.exec_command(f"rm {file_path}") + if use_curl_jsession: + xnat_repo.connection.exit() - ssh_client.close() + if clean_up_older_than: + logger.info( + "Cleaning up files in %s older than %d days", + staged, + clean_up_older_than, + ) + if staged.startswith("s3://"): + remove_old_files_on_s3( + remote_store=staged, threshold=clean_up_older_than + ) + elif "@" in staged: + remove_old_files_on_ssh( + remote_store=staged, threshold=clean_up_older_than + ) + else: + assert False + + if loop: + while True: + start_time = datetime.datetime.now() + do_upload() + end_time = datetime.datetime.now() + elapsed_seconds = (end_time - start_time).total_seconds() + sleep_time = loop - elapsed_seconds + logger.info( + "Stage took %s seconds, waiting another %s seconds before running " + "again (loop every %s seconds)", + elapsed_seconds, + sleep_time, + loop, + ) + time.sleep(loop) + else: + do_upload() if __name__ == "__main__": diff --git a/xnat_ingest/dicom.py b/xnat_ingest/dicom.py deleted file mode 100644 index 2708b38..0000000 --- a/xnat_ingest/dicom.py +++ /dev/null @@ -1,56 +0,0 @@ -import typing as ty -import subprocess as sp - -# import re -import pydicom - -# from fileformats.core import FileSet -# from fileformats.application import Dicom -# from fileformats.extras.application.medical import dicom_read_metadata - -dcmedit_path: ty.Optional[str] -try: - dcmedit_path = sp.check_output("which dcmedit", shell=True).decode("utf-8").strip() -except sp.CalledProcessError: - dcmedit_path = None - -dcminfo_path: ty.Optional[str] -try: - dcminfo_path = sp.check_output("which dcminfo", shell=True).decode("utf-8").strip() -except sp.CalledProcessError: - dcminfo_path = None - - -def tag2keyword(tag: ty.Tuple[str, str]) -> str: - return pydicom.datadict.dictionary_keyword((int(tag[0]), int(tag[1]))) - - -def keyword2tag(keyword: str) -> ty.Tuple[str, str]: - tag = pydicom.datadict.tag_for_keyword(keyword) - if not tag: - raise ValueError(f"Could not find tag for keyword '{keyword}'") - tag_str = hex(tag)[2:] - return (f"{tag_str[:-4].zfill(4)}", tag_str[-4:]) - - -class DicomField: - name = "dicom_field" - - def __init__(self, keyword_or_tag): - # Get the tag associated with the keyword - try: - self.tag = keyword2tag(keyword_or_tag) - except ValueError: - try: - self.keyword = tag2keyword(keyword_or_tag) - except ValueError: - raise ValueError( - f'Could not parse "{keyword_or_tag}" as a DICOM keyword or tag' - ) - else: - self.tag = keyword_or_tag - else: - self.keyword = keyword_or_tag - - def __str__(self): - return f"'{self.keyword}' field ({','.join(self.tag)})" diff --git a/xnat_ingest/exceptions.py b/xnat_ingest/exceptions.py index e1b1711..2efbd9b 100644 --- a/xnat_ingest/exceptions.py +++ b/xnat_ingest/exceptions.py @@ -1,18 +1,24 @@ -class UnsupportedModalityError(Exception): - def __init__(self, msg): +class XnatIngestError(Exception): + def __init__(self, msg: str): self.msg = msg -class StagingError(Exception): - def __init__(self, msg): - self.msg = msg +class UnsupportedModalityError(XnatIngestError): ... -class ImagingSessionParseError(StagingError): - def __init__(self, msg): - self.msg = msg +class StagingError(XnatIngestError): ... 
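The rewritten exceptions module gives every ingest error a shared XnatIngestError base and splits checksum failures into incomplete versus differing variants (IncompleteCheckumsException subclasses DifferingCheckumsException, both shown just below). A minimal sketch, not part of the diff, of how calling code could use that hierarchy around ImagingResource.check_checksums() (defined further down); the verify() helper and its logging are illustrative only:

import logging

from xnat_ingest.exceptions import (
    XnatIngestError,
    IncompleteCheckumsException,
    DifferingCheckumsException,
)

logger = logging.getLogger("xnat-ingest")


def verify(resource) -> bool:
    """Return True if a staged resource still matches its manifest checksums."""
    try:
        resource.check_checksums()
    except IncompleteCheckumsException:
        # catch the subclass first: files listed in the manifest are missing on disk
        logger.warning("resource %s is incomplete", resource.name)
        return False
    except DifferingCheckumsException:
        # files are all present, but at least one no longer matches its recorded checksum
        logger.warning("resource %s was modified after staging", resource.name)
        return False
    except XnatIngestError as exc:
        # any other error raised by xnat-ingest itself
        logger.error("could not verify %s: %s", resource.name, exc)
        return False
    return True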
-class UploadError(Exception): - def __init__(self, msg): - self.msg = msg +class ImagingSessionParseError(StagingError): ... + + +class UploadError(XnatIngestError): ... + + +class DifferingCheckumsException(XnatIngestError): ... + + +class UpdatedFilesException(DifferingCheckumsException): ... + + +class IncompleteCheckumsException(DifferingCheckumsException): ... diff --git a/xnat_ingest/resource.py b/xnat_ingest/resource.py new file mode 100644 index 0000000..56161d4 --- /dev/null +++ b/xnat_ingest/resource.py @@ -0,0 +1,194 @@ +import typing as ty +import logging +import hashlib +from pathlib import Path +from typing_extensions import Self +import shutil +import attrs +from fileformats.application import Json +from fileformats.core import FileSet +from .exceptions import ( + IncompleteCheckumsException, + DifferingCheckumsException, +) +import xnat_ingest.scan + +logger = logging.getLogger("xnat-ingest") + + +@attrs.define +class ImagingResource: + name: str + fileset: FileSet + checksums: dict[str, str] = attrs.field(eq=False, repr=False) + scan: "xnat_ingest.scan.ImagingScan" = attrs.field( + default=None, eq=False, repr=False + ) + + @checksums.default + def calculate_checksums(self) -> dict[str, str]: + return self.fileset.hash_files(crypto=hashlib.md5) + + @property + def datatype(self) -> ty.Type[FileSet]: + return type(self.fileset) + + @property + def metadata(self) -> ty.Mapping[str, ty.Any]: + return self.fileset.metadata # type: ignore[no-any-return] + + @property + def mime_like(self) -> str: + return self.fileset.mime_like + + def __lt__(self, other: Self) -> bool: + try: + scan_id = int(self.scan.id) + except ValueError: + scan_id = self.scan.id # type: ignore[assignment] + try: + other_scan_id = int(other.scan.id) + except ValueError: + other_scan_id = other.scan.id # type: ignore[assignment] + return (scan_id, self.name) < (other_scan_id, other.name) + + def newer_than_or_equal(self, other: Self) -> bool: + return all(s >= m for s, m in zip(self.fileset.mtimes, other.fileset.mtimes)) + + def save( + self, + dest_dir: Path, + copy_mode: FileSet.CopyMode = FileSet.CopyMode.copy, + calculate_checksums: bool = True, + overwrite: bool | None = None, + ) -> Self: + """Save the resource to a directory + + Parameters + ---------- + dest_dir: Path + The directory to save the resource + copy_mode: FileSet.CopyMode + The method to copy the files + calculate_checksums: bool + Whether to calculate the checksums of the files + overwrite: bool + Whether to overwrite the resource if it already exists, if None then the files + are overwritten if they are newer than the ones saved, otherwise a warning is + issued, if False an exception will be raised, if True then the resource is + saved regardless of the files being newer + + Returns + ------- + ImagingResource + The saved resource + + Raises + ------ + FileExistsError + If the resource already exists and overwrite is False or None and the files + are not newer + """ + resource_dir = dest_dir / self.name + checksums = ( + self.calculate_checksums() if calculate_checksums else self.checksums + ) + if resource_dir.exists(): + try: + loaded = self.load(resource_dir, require_manifest=False) + if loaded.checksums == checksums: + return loaded + elif overwrite is None and not self.newer_than_or_equal(loaded): + logger.warning( + f"Resource '{self.name}' already exists in '{dest_dir}' but " + "the files are not older than the ones to be be saved" + ) + elif overwrite: + logger.warning( + f"Resource '{self.name}' already exists in 
'{dest_dir}', overwriting" + ) + shutil.rmtree(resource_dir) + else: + if overwrite is None: + msg = "and the files are not older than the ones to be be saved" + else: + msg = "" + raise FileExistsError( + f"Resource '{self.name}' already exists in '{dest_dir}'{msg}, set " + "'overwrite' to True to overwrite regardless of file times" + ) + except DifferingCheckumsException: + logger.warning( + f"Resource '{self.name}' already exists in '{dest_dir}', but it is " + "incomplete, overwriting" + ) + shutil.rmtree(resource_dir) + saved_fileset = self.fileset.copy(resource_dir, mode=copy_mode, trim=True) + manifest = {"datatype": self.fileset.mime_like, "checksums": checksums} + Json.new(resource_dir / self.MANIFEST_FNAME, manifest) + return type(self)(name=self.name, fileset=saved_fileset, checksums=checksums) + + @classmethod + def load( + cls, + resource_dir: Path, + require_manifest: bool = True, + check_checksums: bool = True, + ) -> Self: + """Load a resource from a directory, reading the manifest file if it exists. + If the manifest file doesn't exist and 'require_manifest' is True then an + exception is raised, if it is False, then a generic FileSet object is loaded + from the files that were found + """ + manifest_file = resource_dir / cls.MANIFEST_FNAME + if manifest_file.exists(): + manifest = Json(manifest_file).load() + checksums = manifest["checksums"] + datatype: ty.Type[FileSet] = FileSet.from_mime(manifest["datatype"]) # type: ignore[assignment] + elif require_manifest: + raise FileNotFoundError( + f"Manifest file not found in '{resource_dir}' resource, set " + "'require_manifest' to False to ignore and load as a generic FileSet object" + ) + else: + checksums = None + datatype = FileSet + fileset = datatype( + p for p in resource_dir.iterdir() if p.name != cls.MANIFEST_FNAME + ) + resource = cls(name=resource_dir.name, fileset=fileset, checksums=checksums) + if checksums is not None and check_checksums: + resource.check_checksums() + return resource + + def check_checksums(self) -> None: + calc_checksums = self.calculate_checksums() + if calc_checksums != self.checksums: + if all(v == self.checksums[k] for k, v in calc_checksums.items()): + missing = list(set(self.checksums) - set(calc_checksums)) + raise IncompleteCheckumsException( + f"Files saved with '{self.name}' resource are incomplete " + f"according to saved checksums, missing {missing}" + ) + + differing = [ + k for k in self.checksums if calc_checksums[k] != self.checksums[k] + ] + raise DifferingCheckumsException( + f"Checksums don't match those saved with '{self.name}' " + f"resource: {differing}" + ) + + def unlink(self) -> None: + """Remove all files in the file-set, the object will be unusable after this""" + for fspath in self.fileset.fspaths: + if fspath.is_file(): + fspath.unlink() + else: + shutil.rmtree(fspath) + + @property + def path(self) -> str: + return self.scan.path + ":" + self.name + + MANIFEST_FNAME = "MANIFEST.json" diff --git a/xnat_ingest/scan.py b/xnat_ingest/scan.py new file mode 100644 index 0000000..e18cbe5 --- /dev/null +++ b/xnat_ingest/scan.py @@ -0,0 +1,96 @@ +import typing as ty +import re +from pathlib import Path +from typing_extensions import Self +import logging +import attrs +from fileformats.core import FileSet +from .resource import ImagingResource +from .utils import AssociatedFiles +import xnat_ingest.session + +logger = logging.getLogger("xnat-ingest") + + +def scan_type_converter(scan_type: str) -> str: + "Ensure there aren't any special characters that aren't valid file/dir 
paths" + return re.sub(r"[\"\*\/\:\<\>\?\\\|\+\,\.\;\=\[\]]+", "", scan_type) + + +def scan_resources_converter( + resources: dict[str, ImagingResource | FileSet] +) -> ty.Dict[str, ImagingResource]: + return { + scan_type_converter(k): ( + v if isinstance(v, ImagingResource) else ImagingResource(k, v) + ) + for k, v in resources.items() + } + + +@attrs.define +class ImagingScan: + """Representation of a scan to be uploaded to XNAT + + Parameters + ---------- + id: str + the ID of the scan on XNAT + type: str + the scan type/description + """ + + id: str + type: str = attrs.field(converter=scan_type_converter) + resources: ty.Dict[str, ImagingResource] = attrs.field( + factory=dict, converter=scan_resources_converter + ) + associated: AssociatedFiles | None = None + session: "xnat_ingest.session.ImagingSession" = attrs.field( + default=None, eq=False, repr=False + ) + + def __contains__(self, resource_name: str) -> bool: + return resource_name in self.resources + + def __getitem__(self, resource_name: str) -> ImagingResource: + return self.resources[resource_name] + + def __attrs_post_init__(self) -> None: + for resource in self.resources.values(): + resource.scan = self + + def new_empty(self) -> Self: + return type(self)(self.id, self.type) + + def save( + self, + dest_dir: Path, + copy_mode: FileSet.CopyMode = FileSet.CopyMode.hardlink_or_copy, + ) -> Self: + # Ensure scan type is a valid directory name + saved = self.new_empty() + scan_dir = dest_dir / f"{self.id}-{self.type}" + scan_dir.mkdir(parents=True, exist_ok=True) + for resource in self.resources.values(): + saved_resource = resource.save(scan_dir, copy_mode=copy_mode) + saved_resource.scan = saved + saved.resources[saved_resource.name] = saved_resource + return saved + + @classmethod + def load(cls, scan_dir: Path, require_manifest: bool = True) -> Self: + scan_id, scan_type = scan_dir.name.split("-", 1) + scan = cls(scan_id, scan_type) + for resource_dir in scan_dir.iterdir(): + if resource_dir.is_dir(): + resource = ImagingResource.load( + resource_dir, require_manifest=require_manifest + ) + resource.scan = scan + scan.resources[resource.name] = resource + return scan + + @property + def path(self) -> str: + return self.session.path + ":" + self.id + "-" + self.type diff --git a/xnat_ingest/session.py b/xnat_ingest/session.py index 118a561..4479895 100644 --- a/xnat_ingest/session.py +++ b/xnat_ingest/session.py @@ -2,71 +2,30 @@ import re from glob import glob import logging -import os.path -import subprocess as sp from functools import cached_property -import shutil import random import string -import platform -from copy import deepcopy from itertools import chain from collections import defaultdict, Counter from pathlib import Path from typing_extensions import Self import attrs from tqdm import tqdm -import yaml -import pydicom -from fileformats.application import Dicom -from fileformats.medimage import DicomSeries -from fileformats.core import from_paths, FileSet, DataType, from_mime, to_mime +from fileformats.medimage import MedicalImage, DicomSeries +from fileformats.core import from_paths, FileSet, from_mime from frametree.core.frameset import FrameSet # type: ignore[import-untyped] -from frametree.core.axes import Axes # type: ignore[import-untyped] -from frametree.core.row import DataRow # type: ignore[import-untyped] -from frametree.core.store import Store # type: ignore[import-untyped] -from frametree.core.entry import DataEntry # type: ignore[import-untyped] -from frametree.core.tree import DataTree # type: 
ignore[import-untyped] from frametree.core.exceptions import FrameTreeDataMatchError # type: ignore[import-untyped] from .exceptions import ImagingSessionParseError, StagingError -from .utils import add_exc_note, transform_paths, AssociatedFiles -from .dicom import dcmedit_path +from .utils import AssociatedFiles, invalid_path_chars_re +from .scan import ImagingScan +from .resource import ImagingResource logger = logging.getLogger("xnat-ingest") -def scan_type_converter(scan_type: str) -> str: - "Ensure there aren't any special characters that aren't valid file/dir paths" - return re.sub(r"[\"\*\/\:\<\>\?\\\|\+\,\.\;\=\[\]]+", "", scan_type) - - -@attrs.define -class ImagingScan: - """Representation of a scan to be uploaded to XNAT - - Parameters - ---------- - id: str - the ID of the scan on XNAT - type: str - the scan type/description - """ - - id: str - type: str = attrs.field(converter=scan_type_converter) - resources: ty.Dict[str, FileSet] = attrs.field() - associated: bool = False - - def __contains__(self, resource_name): - return resource_name in self.resources - - def __getitem__(self, resource_name): - return self.resources[resource_name] - - def scans_converter( scans: ty.Union[ty.Sequence[ImagingScan], ty.Dict[str, ImagingScan]] -): +) -> dict[str, ImagingScan]: if isinstance(scans, ty.Sequence): duplicates = [i for i, c in Counter(s.id for s in scans).items() if c > 1] if duplicates: @@ -86,17 +45,21 @@ class ImagingSession: validator=attrs.validators.instance_of(dict), ) + def __attrs_post_init__(self) -> None: + for scan in self.scans.values(): + scan.session = self + id_escape_re = re.compile(r"[^a-zA-Z0-9_]+") def __getitem__(self, fieldname: str) -> ty.Any: return self.metadata[fieldname] @property - def name(self): + def name(self) -> str: return f"{self.project_id}-{self.subject_id}-{self.visit_id}" @property - def invalid_ids(self): + def invalid_ids(self) -> bool: return ( self.project_id.startswith("INVALID") or self.subject_id.startswith("INVALID") @@ -104,37 +67,46 @@ def invalid_ids(self): ) @property - def staging_relpath(self): - return [self.project_id, self.subject_id, self.visit_id] + def path(self) -> str: + return ":".join([self.project_id, self.subject_id, self.visit_id]) + + @property + def staging_relpath(self) -> list[str]: + return ["-".join([self.project_id, self.subject_id, self.visit_id])] @property - def session_id(self): + def session_id(self) -> str: return self.make_session_id(self.project_id, self.subject_id, self.visit_id) @classmethod - def make_session_id(cls, project_id, subject_id, visit_id): + def make_session_id(cls, project_id: str, subject_id: str, visit_id: str) -> str: return f"{subject_id}_{visit_id}" @cached_property - def modalities(self) -> ty.Set[str]: - modalities = self.metadata["Modality"] - if not isinstance(modalities, str): - modalities = set( - tuple(m) if not isinstance(m, str) else m for m in modalities - ) - return modalities + def modalities(self) -> str | tuple[str, ...]: + modalities_metadata = self.metadata["Modality"] + if isinstance(modalities_metadata, str): + return modalities_metadata + modalities: set[str] = set() + for modality in modalities_metadata: + if isinstance(modality, str): + modalities.add(modality) + else: + assert isinstance(modality, ty.Iterable) + modalities.update(modality) + return tuple(modalities) @property - def parent_dirs(self) -> ty.Set[Path]: + def primary_parents(self) -> ty.Set[Path]: "Return parent directories for all resources in the session" - return set(r.parent for r in 
self.resources) + return set(r.fileset.parent for r in self.primary_resources) @property - def resources(self) -> ty.List[FileSet]: + def resources(self) -> ty.List[ImagingResource]: return [r for p in self.scans.values() for r in p.resources.values()] @property - def primary_resources(self) -> ty.List[FileSet]: + def primary_resources(self) -> ty.List[ImagingResource]: return [ r for s in self.scans.values() @@ -142,11 +114,19 @@ def primary_resources(self) -> ty.List[FileSet]: if not s.associated ] + def new_empty(self) -> Self: + """Return a new empty session with the same IDs as the current session""" + return type(self)( + project_id=self.project_id, + subject_id=self.subject_id, + visit_id=self.visit_id, + ) + def select_resources( self, dataset: ty.Optional[FrameSet], always_include: ty.Sequence[str] = (), - ) -> ty.Iterator[ty.Tuple[str, str, str, FileSet]]: + ) -> ty.Iterator[ImagingResource]: """Returns selected resources that match the columns in the dataset definition Parameters @@ -174,7 +154,7 @@ def select_resources( "Either 'dataset' or 'always_include' must be specified to select " f"appropriate resources to upload from {self.name} session" ) - store = MockStore(self) + store = ImagingSessionMockStore(self) uploaded = set() for mime_like in always_include: @@ -187,10 +167,10 @@ def select_resources( f"{mime_like!r} does not correspond to a file format ({fileformat})" ) for scan in self.scans.values(): - for resource_name, fileset in scan.resources.items(): - if isinstance(fileset, fileformat): - uploaded.add((scan.id, resource_name)) - yield scan.id, scan.type, resource_name, fileset + for resource in scan.resources.values(): + if isinstance(resource.fileset, fileformat): + uploaded.add((scan.id, resource.name)) + yield resource if dataset is not None: for column in dataset.columns.values(): try: @@ -212,12 +192,18 @@ def select_resources( always_include, ) continue - fileset = column.datatype(scan.resources[resource_name]) + resource = scan.resources[resource_name] + if not isinstance(resource.fileset, column.datatype): + resource = ImagingResource( + name=resource_name, + fileset=column.datatype(resource.fileset), + scan=scan, + ) uploaded.add((scan.id, resource_name)) - yield scan_id, scan.type, entry.uri[1], column.datatype(entry.item) + yield resource @cached_property - def metadata(self): + def metadata(self) -> dict[str, ty.Any]: primary_resources = self.primary_resources all_keys = [list(d.metadata.keys()) for d in primary_resources if d.metadata] common_keys = [ @@ -337,12 +323,6 @@ def from_paths( multiple_sessions: ty.DefaultDict[str, ty.Set[ty.Tuple[str, str, str]]] = ( defaultdict(set) ) - multiple_scan_types: ty.DefaultDict[ - ty.Tuple[str, str, str, str], ty.Set[str] - ] = defaultdict(set) - multiple_resources: ty.DefaultDict[ - ty.Tuple[str, str, str, str, str], ty.Set[str] - ] = defaultdict(set) for resource in tqdm( resources, "Sorting resources into XNAT tree structure...", @@ -356,13 +336,13 @@ def get_id(field_type: str, field_name: str) -> str: else: index = None try: - value = str(resource.metadata[field_name]) + value = resource.metadata[field_name] except KeyError: if session_uid and field_type in ("project", "subject", "visit"): value = ( - "INVALID-MISSING-" + "INVALID_MISSING_" + field_type.upper() - + "-" + + "_" + "".join( random.choices( string.ascii_letters + string.digits, k=8 @@ -375,7 +355,9 @@ def get_id(field_type: str, field_name: str) -> str: ) if index is not None: value = value[index] - return value + value_str = str(value) + 
value_str = invalid_path_chars_re.sub("_", value_str) + return value_str if not project_id: project_id = get_id("project", project_field) @@ -412,318 +394,78 @@ def get_id(field_type: str, field_name: str) -> str: multiple_sessions[session_uid].add( (session.project_id, session.subject_id, session.visit_id) ) - - try: - scan = session.scans[scan_id] - except KeyError: - scan = ImagingScan(id=scan_id, type=scan_type, resources={}) - session.scans[scan_id] = scan - else: - if scan.type != scan_type: - # Record all issues with the scan types for raising exception at the end - multiple_scan_types[ - (project_id, subject_id, visit_id, scan_id) - ].add(scan_type) - - if resource_id in scan.resources: - multiple_resources[ - (project_id, subject_id, visit_id, scan_id, scan_type) - ].add(resource_id) - scan.resources[resource_id] = resource - + session.add_resource(scan_id, scan_type, resource_id, resource) if multiple_sessions: raise ImagingSessionParseError( - "Multiple sessions found with the same project/subject/visit ID triplets: " + "Multiple session UIDs found with the same project/subject/visit ID triplets: " + "\n".join( f"{i} -> {p}:{s}:{v}" for i, (p, s, v) in multiple_sessions.items() ) ) - - if multiple_scan_types: - raise ImagingSessionParseError( - "Multiple scans found with the same project/subject/visit/scan ID " - "quadruplets: " - + "\n".join( - f"{p}:{s}:{v}:{sc} -> " + ", ".join(st) - for (p, s, v, sc), st in multiple_scan_types.items() - ) - ) - if multiple_resources: - raise ImagingSessionParseError( - "Multiple resources found with the same project/subject/visit/scan/resource " - "ID quintuplets: " - + "\n".join( - f"{p}:{s}:{v}:{sc}:{r} -> " + ", ".join(rs) - for (p, s, v, sc, r), rs in multiple_resources.items() - ) - ) - return list(sessions.values()) - @classmethod - def load( - cls, session_dir: Path, use_manifest: ty.Optional[bool] = None - ) -> "ImagingSession": - """Loads a session from a directory. Assumes that the name of the directory is - the name of the session dir and the parent directory is the subject ID and the - grandparent directory is the project ID. The scan information is loaded from a YAML - along with the scan type, resources and fileformats. If the YAML file is not found - or `use_manifest` is set to True, the session is loaded based on the directory - structure. + def deidentify( + self, dest_dir: Path, copy_mode: FileSet.CopyMode = FileSet.CopyMode.copy + ) -> Self: + """Creates a new session with deidentified images Parameters ---------- - session_dir : Path - the path to the directory where the session is saved - use_manifest: bool, optional - determines whether to load the session based on YAML manifest or to infer - it from the directory structure. If True the manifest is expected and an error - will be raised if it isn't present, if False the manifest is ignored and if - None the manifest is used if present, otherwise the directory structure is used. 
+ dest_dir : Path + the directory to save the deidentified files into + copy_mode : FileSet.CopyMode, optional + the mode to use to copy the files that don't need to be deidentified, + by default FileSet.CopyMode.copy Returns ------- ImagingSession - the loaded session + a new session with deidentified images """ - project_id = session_dir.parent.parent.name - subject_id = session_dir.parent.name - visit_id = session_dir.name - yaml_file = session_dir / cls.MANIFEST_FILENAME - if yaml_file.exists() and use_manifest is not False: - # Load session from YAML file metadata - try: - with open(yaml_file) as f: - dct = yaml.load(f, Loader=yaml.SafeLoader) - except Exception as e: - add_exc_note( - e, - f"Loading saved session from {yaml_file}, please check that it " - "is a valid YAML file", - ) - raise e - scans = [] - for scan_id, scan_dict in dct["scans"].items(): - scans.append( - ImagingScan( - id=scan_id, - type=scan_dict["type"], - resources={ - n: from_mime(d["datatype"])( # type: ignore[call-arg, misc] - session_dir.joinpath(*p.split("/")) - for p in d["fspaths"] - ) - for n, d in scan_dict["resources"].items() - }, + # Create a new session to save the deidentified files into + deidentified = self.new_empty() + for scan in self.scans.values(): + for resource_name, resource in scan.resources.items(): + resource_dest_dir = dest_dir / scan.id / resource_name + if not isinstance(resource.fileset, MedicalImage): + deid_resource = resource.fileset.copy( + resource_dest_dir, mode=copy_mode, new_stem=resource_name ) - ) - dct["scans"] = scans - session = cls( - project_id=project_id, - subject_id=subject_id, - visit_id=visit_id, - **dct, - ) - elif use_manifest is not True: - # Load session based on directory structure - scans = [] - for scan_dir in session_dir.iterdir(): - if not scan_dir.is_dir(): - continue - scan_id, scan_type = scan_dir.name.split("-", 1) - scan_resources = {} - for resource_dir in scan_dir.iterdir(): - scan_resources[resource_dir.name] = FileSet(resource_dir.iterdir()) - scans.append( - ImagingScan( - id=scan_id, - type=scan_type, - resources=scan_resources, + else: + deid_resource = resource.fileset.deidentify( + resource_dest_dir, copy_mode=copy_mode, new_stem=resource_name ) + deidentified.add_resource( + scan.id, + scan.type, + resource_name, + deid_resource, ) - session = cls( - scans=scans, - project_id=project_id, - subject_id=subject_id, - visit_id=visit_id, - ) - else: - raise FileNotFoundError( - f"Did not find manifest file '{yaml_file}' in session directory " - f"{session_dir}. If you want to fallback to load the session based on " - "the directory structure instead, set `use_manifest` to None." - ) - return session - - def save(self, save_dir: Path, just_manifest: bool = False) -> "ImagingSession": - """Save the project/subject/session IDs loaded from the session to a YAML file, - so they can be manually overridden. 
+ return deidentified - Parameters - ---------- - save_dir: Path - the path to save the session metadata into (NB: the data is typically also - stored in the directory structure of the session, but this is not necessary) - just_manifest : bool, optional - just save the manifest file, not the data, false by default - - Returns - ------- - saved : ImagingSession - a copy of the session with updated paths - """ - scans = {} - saved = deepcopy(self) - session_dir = ( - save_dir / self.project_id / self.subject_id / self.visit_id - ).absolute() - session_dir.mkdir(parents=True, exist_ok=True) - for scan in self.scans.values(): - resources_dict = {} - for resource_name, fileset in scan.resources.items(): - resource_dir = session_dir / f"{scan.id}-{scan.type}" / resource_name - if not just_manifest: - # If data is not already in the save directory, copy it there - logger.debug( - "Checking whether fileset paths %s already inside " - "the save directory %s", - str(fileset.parent), - resource_dir, - ) - if not fileset.parent.absolute().is_relative_to( - resource_dir.absolute() - ): - resource_dir.mkdir(parents=True, exist_ok=True) - fileset = fileset.copy( - resource_dir, mode=fileset.CopyMode.hardlink_or_copy - ) - saved.scans[scan.id].resources[resource_name] = fileset - resources_dict[resource_name] = { - "datatype": to_mime(type(fileset), official=False), - "fspaths": [ - # Ensure it is a relative path using POSIX forward slashes - str(p.absolute().relative_to(session_dir)).replace("\\", "/") - for p in fileset.fspaths - ], - } - scans[scan.id] = { - "type": scan.type, - "resources": resources_dict, - } - yaml_file = session_dir / self.MANIFEST_FILENAME - with open(yaml_file, "w") as f: - yaml.dump( - {"scans": scans}, - f, - ) - return saved - - def stage( + def associate_files( self, - dest_dir: Path, - associated_file_groups: ty.Collection[AssociatedFiles] = (), - remove_original: bool = False, - deidentify: bool = True, - project_list: ty.Optional[ty.List[str]] = None, - spaces_to_underscores: bool = False, - ) -> Self: - r"""Stages and deidentifies files by removing the fields listed `FIELDS_TO_ANONYMISE` and - replacing birth date with 01/01/ and returning new imaging session + patterns: ty.List[AssociatedFiles], + spaces_to_underscores: bool = True, + ) -> None: + """Adds files associated with the primary files to the session Parameters ---------- - dest_dir : Path - destination directory to save the deidentified files. The session will be saved - to a directory with the project, subject and session IDs as subdirectories of - this directory, along with the scans manifest - work_dir : Path, optional - the directory the staged sessions are created in before they are moved into - the staging directory - associated_file_groups : Collection[AssociatedFiles], optional - Glob pattern used to select the non-dicom files to include in the session. Note - that the pattern is relative to the parent directory containing the DICOM files - NOT the current working directory. - The glob pattern can contain string template placeholders corresponding to DICOM - metadata (e.g. '{PatientName.family_name}_{PatientName.given_name}'), which - are substituted before the string is used to glob the non-DICOM files. In - order to deidentify the filenames, the pattern must explicitly reference all - identifiable fields in string template placeholders. By default, None - - Used to extract the scan ID & type/resource from the associated filename. 
Should - be a regular-expression (Python syntax) with named groups called 'id' and 'type', e.g. - '[^\.]+\.[^\.]+\.(?P\d+)\.(?P\w+)\..*' - remove_original : bool, optional - delete original files after they have been staged, false by default - deidentify : bool, optional - deidentify the scans in the staging process, true by default - project_list : list[str], optional - list of available projects in the store, used to check whether the project ID - is valid + patterns : list[AssociatedFiles] + list of patterns to associate files with the primary files in the session spaces_to_underscores : bool, optional when building associated file globs, convert spaces underscores in fields extracted from source file metadata, false by default - - Returns - ------- - ImagingSession - a deidentified session with updated paths """ - if not dcmedit_path: - logger.warning( - "Did not find `dcmedit` tool from the MRtrix package on the system path, " - "de-identification will be performed by pydicom instead and may be slower" - ) - - staged_scans = [] - staged_metadata = {} - if project_list is None or self.project_id in project_list: - project_dir = self.project_id - else: - project_dir = "INVALID-UNRECOGNISED-PROJECT-" + self.project_id - session_dir = dest_dir / project_dir / self.subject_id / self.visit_id - session_dir.mkdir(parents=True) - session_metadata = self.metadata - for scan in tqdm( - self.scans.values(), f"Staging DICOM sessions to {session_dir}" - ): - staged_resources: ty.Dict[str, FileSet] = {} - for resource_name, fileset in scan.resources.items(): - # Ensure scan type is a valid directory name - scan_dir = session_dir / f"{scan.id}-{scan.type}" / resource_name - scan_dir.mkdir(parents=True, exist_ok=True) - if isinstance(fileset, DicomSeries): - staged_dicom_paths = [] - for dicom in fileset.contents: - if deidentify: - dicom_ext = dicom.decomposed_fspaths()[0][-1] - staged_fspath = self.deidentify_dicom( - dicom, - scan_dir - / (dicom.metadata["SOPInstanceUID"] + dicom_ext), - remove_original=remove_original, - ) - elif remove_original: - staged_fspath = dicom.move(scan_dir) - else: - staged_fspath = dicom.copy(scan_dir) - staged_dicom_paths.append(staged_fspath) - staged_resource = DicomSeries(staged_dicom_paths) - # Add to the combined metadata dictionary - staged_metadata.update(staged_resource.metadata) - else: - continue # associated files will be staged later - staged_resources[resource_name] = staged_resource - staged_scans.append( - ImagingScan(id=scan.id, type=scan.type, resources=staged_resources) - ) - for associated_files in associated_file_groups: + for associated_files in patterns: # substitute string templates int the glob template with values from the # DICOM metadata to construct a glob pattern to select files associated # with current session associated_fspaths: ty.Set[Path] = set() - for parent_dir in self.parent_dirs: + for parent_dir in self.primary_parents: assoc_glob = str( - parent_dir / associated_files.glob.format(**session_metadata) + parent_dir / associated_files.glob.format(**self.metadata) ) if spaces_to_underscores: assoc_glob = assoc_glob.replace(" ", "_") @@ -738,54 +480,9 @@ def stage( associated_files.glob, ) - tmpdir = session_dir / ".tmp" - tmpdir.mkdir() - - if deidentify: - # Transform the names of the paths to remove any identiable information - if associated_files.glob.startswith("/") or ( - platform.system() == "Windows" - and re.match(r"[a-zA-Z]:\\", associated_files.glob) - ): - assoc_glob_pattern = associated_files.glob - else: - 
assoc_glob_pattern = ( - str(parent_dir) + os.path.sep + associated_files.glob - ) - transformed_fspaths = transform_paths( - list(associated_fspaths), - assoc_glob_pattern, - session_metadata, - staged_metadata, - spaces_to_underscores=spaces_to_underscores, - ) - staged_associated_fspaths = [] - - for old, new in tqdm( - zip(associated_fspaths, transformed_fspaths), - "Anonymising associated file names", - ): - dest_path = tmpdir / new.name - if Dicom.matches(old): - self.deidentify_dicom( - old, dest_path, remove_original=remove_original - ) - elif remove_original: - logger.debug("Moving %s to %s", old, dest_path) - old.rename(dest_path) - else: - logger.debug("Copying %s to %s", old, dest_path) - shutil.copyfile(old, dest_path) - staged_associated_fspaths.append(dest_path) - else: - staged_associated_fspaths = list(associated_fspaths) - # Identify scan id, type and resource names from deidentified file paths - assoc_scans = {} assoc_re = re.compile(associated_files.identity_pattern) - for fspath in tqdm( - staged_associated_fspaths, "sorting files into resources" - ): + for fspath in tqdm(associated_fspaths, "sorting files into resources"): match = assoc_re.match(str(fspath)) if not match: raise RuntimeError( @@ -793,290 +490,183 @@ def stage( f"did not match file path {fspath}" ) scan_id = match.group("id") - resource = match.group("resource") + resource_name = match.group("resource") try: scan_type = match.group("type") except IndexError: scan_type = scan_id - if scan_id not in assoc_scans: - assoc_resources: ty.DefaultDict[str, ty.List[Path]] = defaultdict( - list - ) - assoc_scans[scan_id] = (scan_type, assoc_resources) - else: - prev_scan_type, assoc_resources = assoc_scans[scan_id] - if scan_type != prev_scan_type: - raise RuntimeError( - f"Mismatched scan types '{scan_type}' and " - f"'{prev_scan_type}' for scan ID '{scan_id}'" - ) - assoc_resources[resource].append(fspath) - for scan_id, (scan_type, scan_resources_dict) in tqdm( - assoc_scans.items(), "moving associated files to staging directory" - ): - scan_resources = {} - for resource_name, fspaths in scan_resources_dict.items(): - if resource_name in self.scans.get(scan_id, []): - raise RuntimeError( - f"Conflict between existing resource and associated files " - f"to stage {scan_id}:{resource_name}" - ) - resource_dir = session_dir / scan_id / resource_name - resource_dir.mkdir(parents=True) - resource_fspaths = [] - for fspath in fspaths: - dest_path = resource_dir / fspath.name - if remove_original or deidentify: - # If deidentify is True then the files will have been copied - # to a temp folder and we can just rename them to their - # final destination - fspath.rename(dest_path) - else: - shutil.copyfile(fspath, dest_path) - resource_fspaths.append(dest_path) - scan_resources[resource_name] = associated_files.datatype( - resource_fspaths - ) - staged_scans.append( - ImagingScan( - id=scan_id, - type=scan_type, - resources=scan_resources, - associated=True, - ) + self.add_resource( + scan_id, + scan_type, + resource_name, + associated_files.datatype(fspath), + associated=associated_files, ) - os.rmdir(tmpdir) # Should be empty - staged = type(self)( - project_id=self.project_id, - subject_id=self.subject_id, - visit_id=self.visit_id, - scans=staged_scans, - ) - staged.save(dest_dir, just_manifest=True) - # If original scans have been moved clear the scans dictionary - if remove_original: - self.scans = {} - return staged - - def deidentify_dicom( - self, dicom_file: Path, new_path: Path, remove_original: bool = False - 
) -> Path: - if dcmedit_path: - # Copy to new path - shutil.copyfile(dicom_file, new_path) - # Replace date of birth date with 1st of Jan - args = [ - dcmedit_path, - "-quiet", - "-anonymise", - str(new_path), - ] - sp.check_call(args) - else: - dcm = pydicom.dcmread(dicom_file) - dcm.PatientBirthDate = "" # dcm.PatientBirthDate[:4] + "0101" - for field in self.FIELDS_TO_CLEAR: - try: - elem = dcm[field] # type: ignore - except KeyError: - pass - else: - elem.value = "" - dcm.save_as(new_path) - if remove_original: - os.unlink(dicom_file) - return new_path - - FIELDS_TO_CLEAR = [ - ("0008", "0014"), # Instance Creator UID - ("0008", "1111"), # Referenced Performed Procedure Step SQ - ("0008", "1120"), # Referenced Patient SQ - ("0008", "1140"), # Referenced Image SQ - ("0008", "0096"), # Referring Physician Identification SQ - ("0008", "1032"), # Procedure Code SQ - ("0008", "1048"), # Physician(s) of Record - ("0008", "1049"), # Physician(s) of Record Identification SQ - ("0008", "1050"), # Performing Physicians' Name - ("0008", "1052"), # Performing Physician Identification SQ - ("0008", "1060"), # Name of Physician(s) Reading Study - ("0008", "1062"), # Physician(s) Reading Study Identification SQ - ("0008", "1110"), # Referenced Study SQ - ("0008", "1111"), # Referenced Performed Procedure Step SQ - ("0008", "1250"), # Related Series SQ - ("0008", "9092"), # Referenced Image Evidence SQ - ("0008", "0080"), # Institution Name - ("0008", "0081"), # Institution Address - ("0008", "0082"), # Institution Code Sequence - ("0008", "0092"), # Referring Physician's Address - ("0008", "0094"), # Referring Physician's Telephone Numbers - ("0008", "009C"), # Consulting Physician's Name - ("0008", "1070"), # Operators' Name - ("0010", "4000"), # Patient Comments - ("0010", "0010"), # Patient's Name - ("0010", "0021"), # Issuer of Patient ID - ("0010", "0032"), # Patient's Birth Time - ("0010", "0050"), # Patient's Insurance Plan Code SQ - ("0010", "0101"), # Patient's Primary Language Code SQ - ("0010", "1000"), # Other Patient IDs - ("0010", "1001"), # Other Patient Names - ("0010", "1002"), # Other Patient IDs SQ - ("0010", "1005"), # Patient's Birth Name - ("0010", "1010"), # Patient's Age - ("0010", "1040"), # Patient's Address - ("0010", "1060"), # Patient's Mother's Birth Name - ("0010", "1080"), # Military Rank - ("0010", "1081"), # Branch of Service - ("0010", "1090"), # Medical Record Locator - ("0010", "2000"), # Medical Alerts - ("0010", "2110"), # Allergies - ("0010", "2150"), # Country of Residence - ("0010", "2152"), # Region of Residence - ("0010", "2154"), # Patient's Telephone Numbers - ("0010", "2160"), # Ethnic Group - ("0010", "2180"), # Occupation - ("0010", "21A0"), # Smoking Status - ("0010", "21B0"), # Additional Patient History - ("0010", "21C0"), # Pregnancy Status - ("0010", "21D0"), # Last Menstrual Date - ("0010", "21F0"), # Patient's Religious Preference - ("0010", "2203"), # Patient's Sex Neutered - ("0010", "2297"), # Responsible Person - ("0010", "2298"), # Responsible Person Role - ("0010", "2299"), # Responsible Organization - ("0020", "9221"), # Dimension Organization SQ - ("0020", "9222"), # Dimension Index SQ - ("0038", "0010"), # Admission ID - ("0038", "0011"), # Issuer of Admission ID - ("0038", "0060"), # Service Episode ID - ("0038", "0061"), # Issuer of Service Episode ID - ("0038", "0062"), # Service Episode Description - ("0038", "0500"), # Patient State - ("0038", "0100"), # Pertinent Documents SQ - ("0040", "0260"), # Performed Protocol Code SQ - 
("0088", "0130"), # Storage Media File-Set ID - ("0088", "0140"), # Storage Media File-Set UID - ("0400", "0561"), # Original Attributes Sequence - ("5200", "9229"), # Shared Functional Groups SQ - ] - MANIFEST_FILENAME = "MANIFEST.yaml" - - -@attrs.define -class MockStore(Store): - """Mock data store so we can use the column.match_entry method on the "entries" in - the data row - """ - - session: ImagingSession - - @property - def row(self): - return DataRow( - ids={DummyAxes._: None}, - frameset=FrameSet(id=None, store=self, hierarchy=[], axes=DummyAxes), - frequency=DummyAxes._, - ) - - def populate_row(self, row: DataRow): - """ - Populate a row with all data entries found in the corresponding node in the data - store (e.g. files within a directory, scans within an XNAT session) using the - ``DataRow.add_entry`` method. Within a node/row there are assumed to be two types - of entries, "primary" entries (e.g. acquired scans) common to all analyses performed - on the dataset and "derivative" entries corresponding to intermediate outputs - of previously performed analyses. These types should be stored in separate - namespaces so there is no chance of a derivative overriding a primary data item. - - The name of the dataset/analysis a derivative was generated by is appended to - to a base path, delimited by "@", e.g. "brain_mask@my_analysis". The dataset - name is left blank by default, in which case "@" is just appended to the - derivative path, i.e. "brain_mask@". + def add_resource( + self, + scan_id: str, + scan_type: str, + resource_name: str, + fileset: FileSet, + overwrite: bool = False, + associated: AssociatedFiles | None = None, + ) -> None: + """Adds a resource to the imaging session Parameters ---------- - row : DataRow - The row to populate with entries + scan_id : str + the ID of the scan to add the resource to + scan_type : str + short description of the type of the scan + resource_name: str + the name of the resource to add + fileset : FileSet + the fileset to add as the resource + overwrite : bool + whether to overwrite existing resource + associated : bool, optional + whether the resource is primary or associated to a primary resource """ - for scan_id, scan in self.session.scans.items(): - for resource_name, resource in scan.resources.items(): - row.add_entry( - path=scan.type + "/" + resource_name, - datatype=type(resource), - uri=(scan_id, resource_name), + try: + scan = self.scans[scan_id] + except KeyError: + scan = self.scans[scan_id] = ImagingScan( + id=scan_id, type=scan_type, associated=associated, session=self + ) + else: + if scan.type != scan_type: + raise ValueError( + f"Non-matching scan types ({scan.type} and {scan_type}) " + f"for scan ID {scan_id}" + ) + if associated != scan.associated: + raise ValueError( + f"Non-matching associated files ({scan.associated} and {associated}) " + f"for scan ID {scan_id}" ) + if resource_name in scan.resources and not overwrite: + raise KeyError( + f"Clash between resource names ('{resource_name}') for {scan_id} scan" + ) + scan.resources[resource_name] = ImagingResource( + name=resource_name, fileset=fileset, scan=scan + ) - def get(self, entry: DataEntry, datatype: type) -> DataType: - """ - Gets the data item corresponding to the given entry + @classmethod + def load( + cls, + session_dir: Path, + require_manifest: bool = True, + check_checksums: bool = True, + ) -> Self: + """Loads a session from a directory. 
Assumes that the name of the directory is + the name of the session dir and the parent directory is the subject ID and the + grandparent directory is the project ID. The scan information is loaded from a YAML + along with the scan type, resources and fileformats. If the YAML file is not found + or `use_manifest` is set to True, the session is loaded based on the directory + structure. Parameters ---------- - entry : DataEntry - the data entry to update - datatype : type - the datatype to interpret the entry's item as + session_dir : Path + the path to the directory where the session is saved + require_manifiest: bool, optional + whether a manifest file is required to load the resources in the session, + if true, resources will only be loaded if the manifest file is found, + if false, resources will be loaded as FileSet types and checksums will not + be checked, by default True + check_checksums: bool, optional + whether to check the checksums of the files in the session, by default True Returns ------- - item : DataType - the item stored within the specified entry + ImagingSession + the loaded session """ - scan_id, resource_name = entry.uri - return datatype(self.session.scans[scan_id][resource_name]) - - ###################################### - # The following methods can be empty # - ###################################### - - def populate_tree(self, tree: DataTree): - pass - - def connect(self) -> ty.Any: - pass - - def disconnect(self, session: ty.Any): - pass + project_id, subject_id, visit_id = session_dir.name.split("-") + session = cls( + project_id=project_id, + subject_id=subject_id, + visit_id=visit_id, + ) + for scan_dir in session_dir.iterdir(): + if scan_dir.is_dir(): + scan = ImagingScan.load(scan_dir, require_manifest=require_manifest) + scan.session = session + session.scans[scan.id] = scan + return session - def create_data_tree( + def save( self, - id: str, - leaves: ty.List[ty.Tuple[str, ...]], - hierarchy: ty.List[str], - axes: type, - **kwargs, - ): - raise NotImplementedError - - ################################### - # The following shouldn't be used # - ################################### - - def put(self, item: DataType, entry: DataEntry) -> DataType: - raise NotImplementedError - - def put_provenance(self, provenance: ty.Dict[str, ty.Any], entry: DataEntry): - raise NotImplementedError + dest_dir: Path, + available_projects: ty.Optional[ty.List[str]] = None, + copy_mode: FileSet.CopyMode = FileSet.CopyMode.hardlink_or_copy, + ) -> tuple[Self, Path]: + r"""Stages and deidentifies files by removing the fields listed `FIELDS_TO_ANONYMISE` and + replacing birth date with 01/01/ and returning new imaging session - def get_provenance(self, entry: DataEntry) -> ty.Dict[str, ty.Any]: - raise NotImplementedError + Parameters + ---------- + dest_dir : Path + destination directory to save the deidentified files. The session will be saved + to a directory with the project, subject and session IDs as subdirectories of + this directory, along with the scans manifest + work_dir : Path, optional + the directory the staged sessions are created in before they are moved into + the staging directory + associated_file_groups : Collection[AssociatedFiles], optional + Glob pattern used to select the non-dicom files to include in the session. Note + that the pattern is relative to the parent directory containing the DICOM files + NOT the current working directory. + The glob pattern can contain string template placeholders corresponding to DICOM + metadata (e.g. 
'{PatientName.family_name}_{PatientName.given_name}'), which + are substituted before the string is used to glob the non-DICOM files. In + order to deidentify the filenames, the pattern must explicitly reference all + identifiable fields in string template placeholders. By default, None - def save_frameset_definition( - self, dataset_id: str, definition: ty.Dict[str, ty.Any], name: str - ): - raise NotImplementedError + Used to extract the scan ID & type/resource from the associated filename. Should + be a regular-expression (Python syntax) with named groups called 'id' and 'type', e.g. + '[^\.]+\.[^\.]+\.(?P\d+)\.(?P\w+)\..*' + remove_original : bool, optional + delete original files after they have been staged, false by default + deidentify : bool, optional + deidentify the scans in the staging process, true by default + project_list : list[str], optional + list of available projects in the store, used to check whether the project ID + is valid + spaces_to_underscores : bool, optional + when building associated file globs, convert spaces underscores in fields + extracted from source file metadata, false by default - def load_frameset_definition( - self, dataset_id: str, name: str - ) -> ty.Dict[str, ty.Any]: - raise NotImplementedError + Returns + ------- + ImagingSession + a deidentified session with updated paths + Path + the path to the directory where the session is saved + """ + saved = self.new_empty() + if available_projects is None or self.project_id in available_projects: + project_id = self.project_id + else: + project_id = "INVALID_UNRECOGNISED_" + self.project_id + session_dir = dest_dir / "-".join((project_id, self.subject_id, self.visit_id)) + session_dir.mkdir(parents=True, exist_ok=True) + for scan in tqdm(self.scans.values(), f"Staging sessions to {session_dir}"): + saved_scan = scan.save(session_dir, copy_mode=copy_mode) + saved_scan.session = saved + saved.scans[saved_scan.id] = saved_scan + return saved, session_dir - def site_licenses_dataset(self): - raise NotImplementedError + MANIFEST_FILENAME = "MANIFEST.yaml" - def create_entry(self, path: str, datatype: type, row: DataRow) -> DataEntry: - raise NotImplementedError + def unlink(self) -> None: + """Unlink all resources in the session""" + for scan in self.scans.values(): + for resource in scan.resources.values(): + resource.unlink() -class DummyAxes(Axes): - _ = 0b0 +from .store import ImagingSessionMockStore # noqa: E402 diff --git a/xnat_ingest/store.py b/xnat_ingest/store.py new file mode 100644 index 0000000..288afe7 --- /dev/null +++ b/xnat_ingest/store.py @@ -0,0 +1,132 @@ +import typing as ty +import attrs +from fileformats.core import DataType +from frametree.core.frameset import FrameSet # type: ignore[import-untyped] +from frametree.core.axes import Axes # type: ignore[import-untyped] +from frametree.core.row import DataRow # type: ignore[import-untyped] +from frametree.core.store import Store # type: ignore[import-untyped] +from frametree.core.entry import DataEntry # type: ignore[import-untyped] +from frametree.core.tree import DataTree # type: ignore[import-untyped] +from .session import ImagingSession + + +@attrs.define +class ImagingSessionMockStore(Store): # type: ignore[misc] + """Mock data store so we can use the column.match_entry method on the "entries" in + the data row + """ + + session: ImagingSession + + @property + def row(self) -> DataRow: + return DataRow( + ids={DummyAxes._: None}, + frameset=FrameSet(id=None, store=self, hierarchy=[], axes=DummyAxes), + frequency=DummyAxes._, + ) + + 
def populate_row(self, row: DataRow) -> None: + """ + Populate a row with all data entries found in the corresponding node in the data + store (e.g. files within a directory, scans within an XNAT session) using the + ``DataRow.add_entry`` method. Within a node/row there are assumed to be two types + of entries, "primary" entries (e.g. acquired scans) common to all analyses performed + on the dataset and "derivative" entries corresponding to intermediate outputs + of previously performed analyses. These types should be stored in separate + namespaces so there is no chance of a derivative overriding a primary data item. + + The name of the dataset/analysis a derivative was generated by is appended + to a base path, delimited by "@", e.g. "brain_mask@my_analysis". The dataset + name is left blank by default, in which case "@" is just appended to the + derivative path, i.e. "brain_mask@". + + Parameters + ---------- + row : DataRow + The row to populate with entries + """ + for scan_id, scan in self.session.scans.items(): + for resource_name, resource in scan.resources.items(): + row.add_entry( + path=scan.type + "/" + resource_name, + datatype=resource.datatype, + uri=(scan_id, resource_name), + ) + + def get(self, entry: DataEntry, datatype: type) -> DataType: + """ + Gets the data item corresponding to the given entry + + Parameters + ---------- + entry : DataEntry + the data entry to retrieve + datatype : type + the datatype to interpret the entry's item as + + Returns + ------- + item : DataType + the item stored within the specified entry + """ + scan_id, resource_name = entry.uri + return datatype(self.session.scans[scan_id][resource_name].fileset) # type: ignore[no-any-return] + + ###################################### + # The following methods can be empty # + ###################################### + + def populate_tree(self, tree: DataTree) -> None: + pass + + def connect(self) -> ty.Any: + pass + + def disconnect(self, session: ty.Any) -> None: + pass + + def create_data_tree( + self, + id: str, + leaves: ty.List[ty.Tuple[str, ...]], + hierarchy: ty.List[str], + axes: type, + **kwargs: ty.Any, + ) -> DataTree: + raise NotImplementedError + + ################################### + # The following shouldn't be used # + ################################### + + def put(self, item: DataType, entry: DataEntry) -> DataType: + raise NotImplementedError + + def put_provenance( + self, provenance: ty.Dict[str, ty.Any], entry: DataEntry + ) -> None: + raise NotImplementedError + + def get_provenance(self, entry: DataEntry) -> ty.Dict[str, ty.Any]: + raise NotImplementedError + + def save_frameset_definition( + self, dataset_id: str, definition: ty.Dict[str, ty.Any], name: str + ) -> None: + raise NotImplementedError + + def load_frameset_definition( + self, dataset_id: str, name: str + ) -> ty.Dict[str, ty.Any]: + raise NotImplementedError + + def site_licenses_dataset(self) -> FrameSet: + raise NotImplementedError + + def create_entry(self, path: str, datatype: type, row: DataRow) -> DataEntry: + raise NotImplementedError + + +class DummyAxes(Axes): # type: ignore[misc] + _ = 0b0 diff --git a/xnat_ingest/tests/test_cli.py b/xnat_ingest/tests/test_cli.py index 88e0c6f..5d3cdc5 100644 --- a/xnat_ingest/tests/test_cli.py +++ b/xnat_ingest/tests/test_cli.py @@ -9,6 +9,7 @@ import xnat4tests # type: ignore[import-untyped] from frametree.core.cli.store import add as store_add # type: ignore[import-untyped] from xnat_ingest.cli import stage, upload +from xnat_ingest.cli.stage import 
STAGED_NAME_DEFAULT from xnat_ingest.utils import show_cli_trace from fileformats.medimage import DicomSeries from medimages4tests.dummy.dicom.pet.wholebody.siemens.biograph_vision.vr20b import ( # type: ignore[import-untyped] @@ -199,10 +200,11 @@ def test_stage_and_upload( str(associated_files_dir) + "/{PatientName.family_name}_{PatientName.given_name}*.ptd", r".*/[^\.]+.[^\.]+.[^\.]+.(?P\d+)\.[A-Z]+_(?P[^\.]+).*", - "--log-file", - str(log_file), + "--logger", + "file", "info", - "--add-logger", + str(log_file), + "--additional-logger", "xnat", "--raise-errors", "--delete", @@ -218,11 +220,12 @@ def test_stage_and_upload( result = cli_runner( upload, [ - str(staging_dir), - "--log-file", - str(log_file), + str(staging_dir / STAGED_NAME_DEFAULT), + "--logger", + "file", "info", - "--add-logger", + str(log_file), + "--additional-logger", "xnat", "--always-include", "medimage/dicom-series", @@ -230,11 +233,13 @@ def test_stage_and_upload( "--method", "tar_file", "--use-curl-jsession", + "--wait-period", + "0", ], env={ - "XNAT_INGEST_UPLOAD_HOST": xnat_server, - "XNAT_INGEST_UPLOAD_USER": "admin", - "XNAT_INGEST_UPLOAD_PASS": "admin", + "XINGEST_HOST": xnat_server, + "XINGEST_USER": "admin", + "XINGEST_PASS": "admin", }, ) diff --git a/xnat_ingest/tests/test_dicom.py b/xnat_ingest/tests/test_dicom.py index 69a4f72..9411d6b 100644 --- a/xnat_ingest/tests/test_dicom.py +++ b/xnat_ingest/tests/test_dicom.py @@ -12,12 +12,14 @@ @pytest.fixture def dicom_series(scope="module") -> DicomSeries: - return DicomSeries(get_pet_image().iterdir()) + return DicomSeries( + get_pet_image(first_name="GivenName", last_name="FamilyName").iterdir() + ) -@pytest.mark.xfail( - condition=(platform.system() == "Linux"), reason="Not working on ubuntu" -) +# @pytest.mark.xfail( +# condition=(platform.system() == "Linux"), reason="Not working on ubuntu" +# ) def test_mrtrix_dicom_metadata(dicom_series: DicomSeries): keys = [ "AccessionNumber", @@ -30,7 +32,7 @@ def test_mrtrix_dicom_metadata(dicom_series: DicomSeries): dicom_series = DicomSeries(dicom_series, specific_tags=keys) assert not (set(keys + ["SpecificCharacterSet"]) - set(dicom_series.metadata)) - assert dicom_series.metadata["PatientName"] == "GivenName^FamilyName" + assert dicom_series.metadata["PatientName"] == "FamilyName^GivenName" assert dicom_series.metadata["AccessionNumber"] == "987654321" assert dicom_series.metadata["PatientID"] == "Session Label" assert dicom_series.metadata["StudyID"] == "PROJECT_ID" diff --git a/xnat_ingest/tests/test_session.py b/xnat_ingest/tests/test_session.py index 8bb285a..a5709a7 100644 --- a/xnat_ingest/tests/test_session.py +++ b/xnat_ingest/tests/test_session.py @@ -1,5 +1,6 @@ from pathlib import Path import pytest +import typing as ty from fileformats.core import from_mime, FileSet from fileformats.medimage import ( DicomSeries, @@ -26,13 +27,38 @@ get_image as get_statistics_image, __file__ as statistics_src_file, ) -from xnat_ingest.session import ImagingSession, ImagingScan, DummyAxes +from xnat_ingest.session import ImagingSession, ImagingScan +from xnat_ingest.store import DummyAxes from xnat_ingest.utils import AssociatedFiles from conftest import get_raw_data_files FIRST_NAME = "Given Name" LAST_NAME = "FamilyName" +DICOM_COLUMNS: ty.List[ty.Tuple[str, str, str]] = [ + ("pet", "medimage/dicom-series", "PET SWB 8MIN"), + ("topogram", "medimage/dicom-series", "Topogram.*"), + ("atten_corr", "medimage/dicom-series", "AC CT.*"), +] + +RAW_COLUMNS: ty.List[ty.Tuple[str, str, str]] = [ + ( + "listmode", + 
"medimage/vnd.siemens.biograph128-vision.vr20b.pet-list-mode", + ".*/PET_LISTMODE", + ), + # ( + # "sinogram", + # "medimage/vnd.siemens.biograph128-vision.vr20b.pet-sinogram", + # ".*/PET_EM_SINO", + # ), + ( + "countrate", + "medimage/vnd.siemens.biograph128-vision.vr20b.pet-count-rate", + ".*/PET_COUNTRATE", + ), +] + @pytest.fixture def imaging_session() -> ImagingSession: @@ -109,26 +135,40 @@ def dataset(tmp_path: Path) -> FrameSet: hierarchy=[], axes=DummyAxes, ) - for col_name, col_type, col_pattern in [ - ("pet", "medimage/dicom-series", "PET SWB 8MIN"), - ("topogram", "medimage/dicom-series", "Topogram.*"), - ("atten_corr", "medimage/dicom-series", "AC CT.*"), - ( - "listmode", - "medimage/vnd.siemens.biograph128-vision.vr20b.pet-list-mode", - ".*/LISTMODE", - ), - # ( - # "sinogram", - # "medimage/vnd.siemens.biograph128-vision.vr20b.pet-sinogram", - # ".*/EM_SINO", - # ), - ( - "countrate", - "medimage/vnd.siemens.biograph128-vision.vr20b.pet-count-rate", - ".*/COUNTRATE", - ), - ]: + for col_name, col_type, col_pattern in DICOM_COLUMNS + RAW_COLUMNS: + dataset.add_source(col_name, from_mime(col_type), col_pattern, is_regex=True) + return dataset + + +@pytest.fixture +def raw_frameset(tmp_path: Path) -> FrameSet: + """For use in tests, this method creates a test dataset from the provided + blueprint + + Parameters + ---------- + store: DataStore + the store to make the dataset within + dataset_id : str + the ID of the project/directory within the store to create the dataset + name : str, optional + the name to give the dataset. If provided the dataset is also saved in the + datastore + source_data : Path, optional + path to a directory containing source data to use instead of the dummy + data + **kwargs + passed through to create_dataset + """ + dataset_path = tmp_path / "a-dataset" + store = FileSystem() + dataset = store.create_dataset( + id=dataset_path, + leaves=[], + hierarchy=[], + axes=DummyAxes, + ) + for col_name, col_type, col_pattern in RAW_COLUMNS: dataset.add_source(col_name, from_mime(col_type), col_pattern, is_regex=True) return dataset @@ -150,25 +190,28 @@ def test_session_select_resources( staging_dir = tmp_path / "staging" staging_dir.mkdir() - staged_session = imaging_session.stage( - staging_dir, - associated_file_groups=[ + imaging_session.associate_files( + patterns=[ AssociatedFiles( Vnd_Siemens_Biograph128Vision_Vr20b_PetRawData, str(assoc_dir) + "/{PatientName.family_name}_{PatientName.given_name}*.ptd", - r".*/[^\.]+.[^\.]+.[^\.]+.(?P\d+)\.[A-Z]+_(?P[^\.]+).*", + r".*/[^\.]+.[^\.]+.[^\.]+.(?P\d+)\.(?P[^\.]+).*", ) ], spaces_to_underscores=True, ) - resources = list(staged_session.select_resources(dataset)) + saved_session, saved_dir = imaging_session.save(staging_dir) + + resources_iter = saved_session.select_resources(dataset) + resources = list(resources_iter) assert len(resources) == 5 # 6 - ids, descs, resource_names, scans = zip(*resources) - assert set(ids) == set(("1", "2", "4", "602")) # , "603")) - assert set(descs) == set( + assert set([r.scan.id for r in resources]) == set( + ("1", "2", "4", "602") + ) # , "603")) + assert set([r.scan.type for r in resources]) == set( [ "AC CT 30 SWB HD_FoV", "PET SWB 8MIN", @@ -177,8 +220,10 @@ def test_session_select_resources( # "603", ] ) - assert set(resource_names) == set(("DICOM", "LISTMODE", "COUNTRATE")) # , "EM_SINO" - assert set(type(s) for s in scans) == set( + assert set([r.name for r in resources]) == set( + ("DICOM", "PET_LISTMODE", "PET_COUNTRATE") + ) # , "PET_EM_SINO" + assert 
set([r.datatype for r in resources]) == set( [ DicomSeries, Vnd_Siemens_Biograph128Vision_Vr20b_PetListMode, @@ -191,7 +236,7 @@ def test_session_select_resources( def test_session_save_roundtrip(tmp_path: Path, imaging_session: ImagingSession): # Save imaging sessions to a temporary directory - saved = imaging_session.save(tmp_path) + saved, _ = imaging_session.save(tmp_path) assert saved is not imaging_session # Calculate where the session should have been saved to @@ -207,12 +252,64 @@ def test_session_save_roundtrip(tmp_path: Path, imaging_session: ImagingSession) rereloaded = ImagingSession.load(session_dir) assert rereloaded == saved - # Load from saved directory, this time only using directory structure instead of - # manifest. Should be the same with the exception of the detected fileformats - loaded_no_manifest = ImagingSession.load(session_dir, use_manifest=False) - for scan in loaded_no_manifest.scans.values(): - for key, resource in list(scan.resources.items()): - if key == "DICOM": - assert isinstance(resource, FileSet) - scan.resources[key] = DicomSeries(resource) - assert loaded_no_manifest == saved + # # Load from saved directory, this time only using directory structure instead of + # # manifest. Should be the same with the exception of the detected fileformats + # loaded_no_manifest = ImagingSession.load(session_dir, require_manifest=False) + # for scan in loaded_no_manifest.scans.values(): + # for key, resource in list(scan.resources.items()): + # if key == "DICOM": + # assert isinstance(resource, FileSet) + # scan.resources[key] = DicomSeries(resource) + # assert loaded_no_manifest == saved + + +def test_stage_raw_data_directly(raw_frameset: FrameSet, tmp_path: Path): + + raw_data_dir = tmp_path / "raw" + raw_data_dir.mkdir() + + num_sessions = 2 + + for i in range(num_sessions): + sess_dir = raw_data_dir / str(i) + sess_dir.mkdir() + get_raw_data_files( + out_dir=sess_dir, + first_name=FIRST_NAME + str(i), + last_name=LAST_NAME + str(i), + StudyID=f"Study{i}", + PatientID=f"Patient{i}", + AccessionNumber=f"AccessionNumber{i}", + StudyInstanceUID=f"StudyInstanceUID{i}", + ) + + imaging_sessions = ImagingSession.from_paths( + f"{raw_data_dir}/**/*.ptd", + datatypes=[Vnd_Siemens_Biograph128Vision_Vr20b_PetRawData], + ) + + staging_dir = tmp_path / "staging" + staging_dir.mkdir() + + staged_sessions = [] + + for imaging_session in imaging_sessions: + staged_sessions.append( + imaging_session.save( + staging_dir, + )[0] + ) + + for staged_session in staged_sessions: + resources = list(staged_session.select_resources(raw_frameset)) + + assert len(resources) == 2 + assert set([r.scan.id for r in resources]) == set(["602"]) + assert set([r.scan.type for r in resources]) == set(["PET Raw Data"]) + assert set(r.name for r in resources) == set(("PET_LISTMODE", "PET_COUNTRATE")) + assert set(type(r.fileset) for r in resources) == set( + [ + Vnd_Siemens_Biograph128Vision_Vr20b_PetListMode, + Vnd_Siemens_Biograph128Vision_Vr20b_PetCountRate, + ] + ) diff --git a/xnat_ingest/upload_helpers.py b/xnat_ingest/upload_helpers.py new file mode 100644 index 0000000..73665b7 --- /dev/null +++ b/xnat_ingest/upload_helpers.py @@ -0,0 +1,371 @@ +from pathlib import Path +import shutil +import os +import datetime +import typing as ty +from collections import defaultdict +import tempfile +from tqdm import tqdm +import hashlib +import pprint +import boto3 +import paramiko +from xnat_ingest.utils import ( + logger, + StoreCredentials, +) +from fileformats.core import FileSet +from .session import 
ImagingSession +from .resource import ImagingResource + + +def iterate_s3_sessions( + bucket_path: str, + store_credentials: StoreCredentials, + temp_dir: Path | None, + wait_period: int, +) -> ty.Iterator[Path]: + """Iterate over sessions stored in an S3 bucket + + Parameters + ---------- + bucket_path : str + the path to the S3 bucket + store_credentials : StoreCredentials + the credentials to access the S3 bucket + temp_dir : Path, optional + the temporary directory to download the sessions to, by default None + wait_period : int + the number of seconds after the last write before considering a session complete + """ + # List sessions stored in s3 bucket + s3 = boto3.resource( + "s3", + aws_access_key_id=store_credentials.access_key, + aws_secret_access_key=store_credentials.access_secret, + ) + bucket_name, prefix = bucket_path[5:].split("/", 1) + bucket = s3.Bucket(bucket_name) + if not prefix.endswith("/"): + prefix += "/" + all_objects = bucket.objects.filter(Prefix=prefix) + session_objs = defaultdict(list) + for obj in all_objects: + if obj.key.endswith("/"): + continue # skip directories + path_parts = obj.key[len(prefix) :].split("/") + session_name = path_parts[0] + session_objs[session_name].append((path_parts[1:], obj)) + + num_sessions = len(session_objs) + # Bit of a hack to allow the caller to know how many sessions are in the bucket + # we yield the number of sessions as the first item in the iterator + yield num_sessions # type: ignore[misc] + + if temp_dir: + tmp_download_dir = temp_dir / "xnat-ingest-download" + tmp_download_dir.mkdir(parents=True, exist_ok=True) + else: + tmp_download_dir = Path(tempfile.mkdtemp()) + + for session_name, objs in session_objs.items(): + # Just in case the manifest file is not included in the list of objects + # we recreate the project/subject/session directory structure + session_tmp_dir = tmp_download_dir / session_name + session_tmp_dir.mkdir(parents=True, exist_ok=True) + # Check to see if the session is still being updated + last_modified = None + for _, obj in objs: + if last_modified is None or obj.last_modified > last_modified: + last_modified = obj.last_modified + assert last_modified is not None + if (datetime.datetime.now() - last_modified) >= datetime.timedelta( + seconds=wait_period + ): + for relpath, obj in tqdm( + objs, + desc=f"Downloading scans in '{session_name}' session from S3 bucket", + ): + if last_modified is None or obj.last_modified > last_modified: + last_modified = obj.last_modified + obj_path = session_tmp_dir.joinpath(*relpath) + obj_path.parent.mkdir(parents=True, exist_ok=True) + logger.debug("Downloading %s to %s", obj, obj_path) + with open(obj_path, "wb") as f: + bucket.download_fileobj(obj.key, f) + yield session_tmp_dir + else: + logger.info( + "Skipping session '%s' as it was last modified less than %d seconds ago " + "and waiting until it is complete", + session_name, + wait_period, + ) + shutil.rmtree(session_tmp_dir) # Delete the tmp session after the upload + + logger.info("Found %d sessions in S3 bucket '%s'", num_sessions, bucket_path) + logger.debug("Created sessions iterator") + + +def remove_old_files_on_s3(remote_store: str, threshold: int) -> None: + # Parse S3 bucket and prefix from remote store + bucket_name, prefix = remote_store[5:].split("/", 1) + + # Create S3 client + s3_client = boto3.client("s3") + + # List objects in the bucket with the specified prefix + response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + + now = datetime.datetime.now() + + # Iterate over 
objects and delete files older than the threshold + for obj in response.get("Contents", []): + last_modified = obj["LastModified"] + age = (now - last_modified).days + if age > threshold: + s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) + + +def remove_old_files_on_ssh(remote_store: str, threshold: int) -> None: + # Parse SSH server and directory from remote store + server, directory = remote_store.split("@", 1) + + # Create SSH client + ssh_client = paramiko.SSHClient() + ssh_client.load_system_host_keys() + ssh_client.connect(server) + + # Execute find command to list files in the directory + stdin, stdout, stderr = ssh_client.exec_command(f"find {directory} -type f") + + now = datetime.datetime.now() + + # Iterate over files and delete files older than the threshold + for file_path in stdout.read().decode().splitlines(): + last_modified = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)) + age = (now - last_modified).days + if age > threshold: + ssh_client.exec_command(f"rm {file_path}") + + ssh_client.close() + + +def get_xnat_session(session: ImagingSession, xproject: ty.Any) -> ty.Any: + """Get the XNAT session object for the given session + + Parameters + ---------- + session : ImagingSession + the session to upload + xnat_repo : Xnat + the XNAT repository to upload to + + Returns + ------- + xsession : ty.Any + the XNAT session object + """ + xclasses = xproject.xnat_session.classes + + xsubject = xclasses.SubjectData(label=session.subject_id, parent=xproject) + try: + xsession = xproject.experiments[session.session_id] + except KeyError: + if "MR" in session.modalities: + SessionClass = xclasses.MrSessionData + elif "PT" in session.modalities: + SessionClass = xclasses.PetSessionData + elif "CT" in session.modalities: + SessionClass = xclasses.CtSessionData + else: + raise RuntimeError( + "Found the following unsupported modalities in " + f"{session.name}: {session.modalities}" + ) + xsession = SessionClass(label=session.session_id, parent=xsubject) + return xsession + + +def get_xnat_resource(resource: ImagingResource, xsession: ty.Any) -> ty.Any: + """Get the XNAT resource object for the given resource + + Parameters + ---------- + resource : ImagingResource + the resource to upload + xsession : ty.Any + the XNAT session object + + Returns + ------- + xresource : ty.Any + the XNAT resource object + """ + xclasses = xsession.xnat_session.classes + try: + xscan = xsession.scans[resource.scan.id] + except KeyError: + if isinstance(xsession, xclasses.MrSessionData): + default_scan_modality = "MR" + elif isinstance(xsession, xclasses.PetSessionData): + default_scan_modality = "PT" + else: + default_scan_modality = "CT" + if resource.metadata: + image_type = resource.metadata.get("ImageType") + if image_type and image_type[:2] == [ + "DERIVED", + "SECONDARY", + ]: + modality = "SC" + resource_name = "secondary" + else: + modality = resource.metadata.get("Modality", default_scan_modality) + else: + modality = default_scan_modality + if modality == "SC": + ScanClass = xclasses.ScScanData + elif modality == "MR": + ScanClass = xclasses.MrScanData + elif modality == "PT": + ScanClass = xclasses.PetScanData + elif modality == "CT": + ScanClass = xclasses.CtScanData + else: + SessionClass = type(xsession) + if SessionClass is xclasses.PetSessionData: + ScanClass = xclasses.PetScanData + elif SessionClass is xclasses.CtSessionData: + ScanClass = xclasses.CtScanData + else: + ScanClass = xclasses.MrScanData + logger.info( + "Can't determine modality of %s-%s scan, 
defaulting to the " + "default for %s sessions, %s", + resource.scan.id, + resource.scan.type, + SessionClass, + ScanClass, + ) + logger.debug( + "Creating scan %s in %s", resource.scan.id, resource.scan.session.path + ) + xscan = ScanClass( + id=resource.scan.id, + type=resource.scan.type, + parent=xsession, + ) + try: + xresource = xscan.resources[resource.name] + except KeyError: + pass + else: + checksums = get_xnat_checksums(xresource) + if checksums == resource.checksums: + logger.info( + "Skipping '%s' resource in '%s' as it " "already exists on XNAT", + resource.name, + resource.scan.path, + ) + else: + difference = { + k: (v, resource.checksums[k]) + for k, v in checksums.items() + if v != resource.checksums[k] + } + logger.error( + "'%s' resource in '%s' already exists on XNAT with " + "different checksums. Please delete on XNAT to overwrite:\n%s", + resource.name, + resource.scan.path, + pprint.pformat(difference), + ) + return None + logger.debug( + "Creating resource %s in %s", + resource.name, + resource.scan.path, + ) + xresource = xscan.create_resource(resource.name) + return xresource + + +def get_xnat_checksums(xresource: ty.Any) -> dict[str, str]: + """ + Downloads the MD5 digests associated with the files in a resource. + + Parameters + ---------- + xresource : xnat.classes.Resource + XNAT resource to retrieve the checksums from + + Returns + ------- + dict[str, str] + the checksums calculated by XNAT + """ + result = xresource.xnat_session.get(xresource.uri + "/files") + if result.status_code != 200: + raise RuntimeError( + "Could not download metadata for resource {}. Files " + "may have been uploaded but cannot check checksums".format(xresource.id) + ) + return dict((r["Name"], r["digest"]) for r in result.json()["ResultSet"]["Result"]) + + +def calculate_checksums(scan: FileSet) -> ty.Dict[str, str]: + """ + Calculates the MD5 digests associated with the files in a fileset. + + Parameters + ---------- + scan : FileSet + the file-set to calculate the checksums for + + Returns + ------- + dict[str, str] + the calculated checksums + """ + checksums = {} + for fspath in scan.fspaths: + try: + hsh = hashlib.md5() + with open(fspath, "rb") as f: + for chunk in iter(lambda: f.read(HASH_CHUNK_SIZE), b""): + hsh.update(chunk) + checksum = hsh.hexdigest() + except OSError: + raise RuntimeError(f"Could not create digest of '{fspath}' ") + checksums[str(fspath.relative_to(scan.parent))] = checksum + return checksums + + +HASH_CHUNK_SIZE = 2**20 + + +def dir_older_than(path: Path, period: int) -> bool: + """ + Get the most recent modification time of a directory and its contents. 
+ + Parameters + ---------- + path : Path + the directory to get the modification time of + period : int + the number of seconds after the last modification time to check against + + Returns + ------- + bool + whether the directory is older than the specified period + """ + mtimes = [path.stat().st_mtime] + for root, _, files in os.walk(path): + for file in files: + mtimes.append((Path(root) / file).stat().st_mtime) + last_modified = datetime.datetime.fromtimestamp(max(mtimes)) + return (datetime.datetime.now() - last_modified) >= datetime.timedelta( + seconds=period + ) diff --git a/xnat_ingest/utils.py b/xnat_ingest/utils.py index c7d3009..f62e584 100644 --- a/xnat_ingest/utils.py +++ b/xnat_ingest/utils.py @@ -5,11 +5,11 @@ from pathlib import Path import sys import typing as ty -import hashlib import attrs import click.types +import click.testing +import discord from fileformats.core import DataType, FileSet, from_mime -from .dicom import DicomField # noqa logger = logging.getLogger("xnat-ingest") @@ -24,10 +24,10 @@ def datatype_converter( class classproperty(object): - def __init__(self, f): + def __init__(self, f: ty.Callable[..., ty.Any]) -> None: self.f = f - def __get__(self, obj, owner): + def __get__(self, obj: object, owner: ty.Any) -> ty.Any: return self.f(owner) @@ -37,82 +37,64 @@ class CliType(click.types.ParamType): def __init__( self, - type_, - multiple=False, + type_: ty.Type[ty.Union["CliTyped", "MultiCliTyped"]], + multiple: bool = False, ): self.type = type_ self.multiple = multiple def convert( self, value: ty.Any, param: click.Parameter | None, ctx: click.Context | None - ): + ) -> ty.Any: if isinstance(value, self.type): return value return self.type(*value) @property - def arity(self): + def arity(self) -> int: # type: ignore[override] return len(attrs.fields(self.type)) @property - def name(self): + def name(self) -> str: # type: ignore[override] return type(self).__name__.lower() - def split_envvar_value(self, envvar): + def split_envvar_value(self, envvar: str) -> ty.Any: if self.multiple: return [self.type(*entry.split(",")) for entry in envvar.split(";")] else: return self.type(*envvar.split(",")) +@attrs.define class CliTyped: @classproperty - def cli_type(cls): - return CliType(cls) + def cli_type(cls) -> CliType: + return CliType(cls) # type: ignore[arg-type] +@attrs.define class MultiCliTyped: @classproperty - def cli_type(cls): - return CliType(cls, multiple=True) - + def cli_type(cls) -> CliType: + return CliType(cls, multiple=True) # type: ignore[arg-type] -@attrs.define -class LogEmail(CliTyped): - - address: str - loglevel: str - subject: str - def __str__(self): - return self.address +def to_upper(value: str) -> str: + return value.upper() @attrs.define -class LogFile(MultiCliTyped): +class LoggerConfig(MultiCliTyped): - path: Path = attrs.field(converter=Path) + type: str loglevel: str + location: str - def __bool__(self): - return bool(self.path) - - def __str__(self): - return str(self.path) - - def __fspath__(self): - return str(self.path) - - -@attrs.define -class MailServer(CliTyped): - - host: str - sender_email: str - user: str - password: str + @property + def loglevel_int(self) -> int: + return getattr(logging, self.loglevel.upper()) # type: ignore[no-any-return] @attrs.define @@ -139,130 +121,58 @@ class StoreCredentials(CliTyped): def set_logger_handling( - log_level: str, - log_emails: ty.List[LogEmail] | None, - log_files: ty.List[LogFile] | None, - mail_server: MailServer, - add_logger: ty.Sequence[str] = (), -): + logger_configs: 
ty.Sequence[LoggerConfig], + additional_loggers: ty.Sequence[str] = (), +) -> None: + """Set up logging for the application""" loggers = [logger] - for log in add_logger: + for log in additional_loggers: loggers.append(logging.getLogger(log)) - levels = [log_level] - if log_emails: - levels.extend(le.loglevel for le in log_emails) - if log_files: - levels.extend(lf.loglevel for lf in log_files) - - min_log_level = min(getattr(logging, ll.upper()) for ll in levels) + min_log_level = min(ll.loglevel_int for ll in logger_configs) for logr in loggers: logr.setLevel(min_log_level) - # Configure the email logger - if log_emails: - if not mail_server: - raise ValueError( - "Mail server needs to be provided, either by `--mail-server` option or " - "XNAT_INGEST_MAILSERVER environment variable if logger emails " - "are provided: " + ", ".join(str(le) for le in log_emails) - ) - for log_email in log_emails: - smtp_hdle = logging.handlers.SMTPHandler( - mailhost=mail_server.host, - fromaddr=mail_server.sender_email, - toaddrs=[log_email.address], - subject=log_email.subject, - credentials=(mail_server.user, mail_server.password), - secure=None, - ) - smtp_hdle.setLevel(getattr(logging, log_email.loglevel.upper())) - for logr in loggers: - logr.addHandler(smtp_hdle) - # Configure the file logger - if log_files: - for log_file in log_files: - log_file.path.parent.mkdir(exist_ok=True) - log_file_hdle = logging.FileHandler(log_file) - if log_file.loglevel: - log_file_hdle.setLevel(getattr(logging, log_file.loglevel.upper())) - log_file_hdle.setFormatter( - logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - ) - for logr in loggers: - logr.addHandler(log_file_hdle) - - console_hdle = logging.StreamHandler(sys.stdout) - console_hdle.setLevel(getattr(logging, log_level.upper())) - console_hdle.setFormatter( - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - ) - for logr in loggers: - logr.addHandler(console_hdle) - - -def get_checksums(xresource) -> ty.Dict[str, str]: - """ - Downloads the MD5 digests associated with the files in a resource. - - Parameters - ---------- - xresource : xnat.classes.Resource - XNAT resource to retrieve the checksums from - - Returns - ------- - dict[str, str] - the checksums calculated by XNAT - """ - result = xresource.xnat_session.get(xresource.uri + "/files") - if result.status_code != 200: - raise RuntimeError( - "Could not download metadata for resource {}. Files " - "may have been uploaded but cannot check checksums".format(xresource.id) + for config in logger_configs: + log_handle: logging.Handler + if config.type == "file": + Path(config.location).parent.mkdir(parents=True, exist_ok=True) + log_handle = logging.FileHandler(config.location) + elif config.type == "stream": + stream = sys.stderr if config.location == "stderr" else sys.stdout + log_handle = logging.StreamHandler(stream) + elif config.type == "discord": + log_handle = DiscordHandler(config.location) + else: + raise ValueError(f"Unknown logger type: {config.type}") + log_handle.setLevel(config.loglevel_int) + log_handle.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ) - return dict((r["Name"], r["digest"]) for r in result.json()["ResultSet"]["Result"]) + for logr in loggers: + logr.addHandler(log_handle) -def calculate_checksums(scan: FileSet) -> ty.Dict[str, str]: - """ - Calculates the MD5 digests associated with the files in a fileset. 
- - Parameters - ---------- - scan : FileSet - the file-set to calculate the checksums for - - Returns - ------- - dict[str, str] - the calculated checksums - """ - checksums = {} - for fspath in scan.fspaths: - try: - hsh = hashlib.md5() - with open(fspath, "rb") as f: - for chunk in iter(lambda: f.read(HASH_CHUNK_SIZE), b""): - hsh.update(chunk) - checksum = hsh.hexdigest() - except OSError: - raise RuntimeError(f"Could not create digest of '{fspath}' ") - checksums[str(fspath.relative_to(scan.parent))] = checksum - return checksums +def show_cli_trace(result: click.testing.Result) -> str: + """Show the exception traceback from CLIRunner results""" + assert result.exc_info is not None + exc_type, exc, tb = result.exc_info + return "".join(traceback.format_exception(exc_type, value=exc, tb=tb)) -HASH_CHUNK_SIZE = 2**20 +class DiscordHandler(logging.Handler): + """A logging handler that sends log messages to a Discord webhook""" + def __init__(self, webhook_url: str): + super().__init__() + self.webhook_url = webhook_url + self.client = discord.Webhook.from_url(webhook_url) -def show_cli_trace(result): - """Show the exception traceback from CLIRunner results""" - return "".join(traceback.format_exception(*result.exc_info)) + def emit(self, record: logging.LogRecord) -> None: + self.client.send(record.msg) class RegexExtractor: @@ -289,7 +199,7 @@ def __call__(self, to_match: str) -> str: return extracted -def add_exc_note(e, note): +def add_exc_note(e: Exception, note: str) -> Exception: """Adds a note to an exception in a Python <3.11 compatible way Parameters @@ -352,7 +262,7 @@ def transform_paths( group_count: Counter[str] = Counter() # Create regex groups for string template args - def str_templ_to_regex_group(match) -> str: + def str_templ_to_regex_group(match: re.Match[str]) -> str: fieldname = match.group(0)[1:-1] if "." in fieldname: fieldname, attr_name = fieldname.split(".") @@ -374,7 +284,8 @@ def str_templ_to_regex_group(match) -> str: transform_path_re = re.compile(transform_path_pattern + "$") # Define a custom replacement function - def replace_named_groups(match): + def replace_named_groups(match: re.Match[str]) -> str: + assert match.lastgroup is not None return new_values.get(match.lastgroup, match.group()) transformed = [] @@ -432,22 +343,22 @@ def glob_to_re(glob_pattern: str) -> str: # W/o leading or trailing ``/`` two consecutive asterisks will be treated as literals. # Edge-case #1. Catches recursive globs in the middle of path. Requires edge # case #2 handled after this case. - ("/\*\*", "(?:/.+?)*"), + (r"/\*\*", "(?:/.+?)*"), # Edge-case #2. Catches recursive globs at the start of path. Requires edge # case #1 handled before this case. ``^`` is used to ensure proper location for ``**/``. - ("\*\*/", "(?:^.+?/)*"), + (r"\*\*/", "(?:^.+?/)*"), # ``[^/]*`` is used to ensure that ``*`` won't match subdirs, as with naive # ``.*?`` solution. - ("\*", "[^/]*"), - ("\?", "."), - ("\[\*\]", "\*"), # Escaped special glob character. - ("\[\?\]", "\?"), # Escaped special glob character. + (r"\*", "[^/]*"), + (r"\?", "."), + (r"\[\*\]", r"\*"), # Escaped special glob character. + (r"\[\?\]", r"\?"), # Escaped special glob character. # Requires ordered dict, so that ``\[!`` preceded ``\[`` in RE pattern. Needed # mostly to differentiate between ``!`` used within character class ``[]`` and # outside of it, to avoid faulty conversion. 
- ("\[!", "[^"), - ("\[", "["), - ("\]", "]"), + (r"\[!", "[^"), + (r"\[", "["), + (r"\]", "]"), ) ) @@ -456,3 +367,5 @@ def glob_to_re(glob_pattern: str) -> str: ) _str_templ_replacement = re.compile(r"\{[\w\.]+\}") + +invalid_path_chars_re = re.compile(r'[<>:"/\\|?*\x00-\x1F]')