diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py
index 1b4b8994..40563aab 100644
--- a/cumulus_etl/etl/cli.py
+++ b/cumulus_etl/etl/cli.py
@@ -6,7 +6,6 @@
 import os
 import shutil
 import sys
-import tempfile
 from collections.abc import Iterable
 
 import rich
@@ -24,21 +23,6 @@
 ###############################################################################
 
 
-async def load_and_deidentify(loader: loaders.Loader, resources: Iterable[str]) -> tempfile.TemporaryDirectory:
-    """
-    Loads the input directory and does a first-pass de-identification
-
-    Code outside this method should never see the original input files.
-
-    :returns: a temporary directory holding the de-identified files in FHIR ndjson format
-    """
-    # First step is loading all the data into a local ndjson format
-    loaded_dir = await loader.load_all(list(resources))
-
-    # Second step is de-identifying that data (at a bulk level)
-    return await deid.Scrubber.scrub_bulk_data(loaded_dir.name)
-
-
 async def etl_job(
     config: JobConfig, selected_tasks: list[type[tasks.EtlTask]], use_philter: bool = False
 ) -> list[JobSummary]:
@@ -238,13 +222,20 @@ async def etl_main(args: argparse.Namespace) -> None:
         root_input, client=client, export_to=args.export_to, since=args.since, until=args.until
     )
 
-    # Pull down resources and run the MS tool on them
-    deid_dir = await load_and_deidentify(config_loader, required_resources)
+    # Pull down resources from any remote location (like s3), convert from i2b2, or do a bulk export
+    loaded_dir = await config_loader.load_all(list(required_resources))
+
+    # If *any* of our tasks need bulk MS de-identification, run it
+    if any(t.needs_bulk_deid for t in selected_tasks):
+        loaded_dir = await deid.Scrubber.scrub_bulk_data(loaded_dir.name)
+    else:
+        print("Skipping bulk de-identification.")
+        print("These selected tasks will de-identify resources as they are processed.")
 
     # Prepare config for jobs
     config = JobConfig(
         args.dir_input,
-        deid_dir.name,
+        loaded_dir.name,
         args.dir_output,
         args.dir_phi,
         args.input_format,
diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
index 288be675..0aad990e 100644
--- a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
+++ b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
@@ -64,6 +64,7 @@ class CovidSymptomNlpResultsTask(tasks.EtlTask):
     name = "covid_symptom__nlp_results"
     resource = "DocumentReference"
     tags = {"covid_symptom", "gpu"}
+    needs_bulk_deid = False
     outputs = [tasks.OutputTable(schema=None, group_field="docref_id")]
 
     def __init__(self, *args, **kwargs):
diff --git a/cumulus_etl/etl/studies/hftest/hf_tasks.py b/cumulus_etl/etl/studies/hftest/hf_tasks.py
index 749527e4..f9d85aeb 100644
--- a/cumulus_etl/etl/studies/hftest/hf_tasks.py
+++ b/cumulus_etl/etl/studies/hftest/hf_tasks.py
@@ -15,6 +15,7 @@ class HuggingFaceTestTask(tasks.EtlTask):
 
     name = "hftest__summary"
     resource = "DocumentReference"
+    needs_bulk_deid = False
     outputs = [tasks.OutputTable(schema=None)]
 
     # Task Version
diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py
index 4640a456..c83e407c 100644
--- a/cumulus_etl/etl/tasks/base.py
+++ b/cumulus_etl/etl/tasks/base.py
@@ -75,6 +75,7 @@ class EtlTask:
     name: str = None  # task & table name
     resource: str = None  # incoming resource that this task operates on (will be included in bulk exports etc)
     tags: set[str] = []
+    needs_bulk_deid = True  # whether this task needs bulk MS tool de-id run on its inputs (NLP tasks usually don't)
     outputs: list[OutputTable] = [OutputTable()]
diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py
index 43c217b9..1fcbda97 100644
--- a/tests/etl/test_etl_cli.py
+++ b/tests/etl/test_etl_cli.py
@@ -7,6 +7,7 @@
 import tempfile
 from unittest import mock
 
+import ddt
 import pytest
 from ctakesclient.typesystem import Polarity
 
@@ -88,6 +89,7 @@ def assert_output_equal(self, folder: str):
         self.assert_etl_output_equal(os.path.join(self.data_dir, folder), self.output_path)
 
 
+@ddt.ddt
 class TestEtlJobFlow(BaseEtlSimple):
     """Test case for the sequence of data through the system"""
 
@@ -95,6 +97,19 @@ async def test_batched_output(self):
         await self.run_etl(batch_size=1)
         self.assert_output_equal("batched-output")
 
+    @ddt.data(
+        (["covid_symptom__nlp_results"], False),
+        (["patient"], True),
+        (["covid_symptom__nlp_results", "patient"], True),
+    )
+    @ddt.unpack
+    async def test_ms_deid_skipped_if_not_needed(self, tasks: list[str], expected_ms_deid: bool):
+        with self.assertRaises(SystemExit):
+            with mock.patch("cumulus_etl.deid.Scrubber.scrub_bulk_data") as mock_deid:
+                with mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit):
+                    await self.run_etl(tasks=tasks)
+        self.assertEqual(1 if expected_ms_deid else 0, mock_deid.call_count)
+
     async def test_downloaded_phi_is_not_kept(self):
         """Verify we remove all downloaded PHI even if interrupted"""
         internal_phi_dir = None
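
Reviewer note: the new control flow is spread across three hunks, so below is a minimal, self-contained sketch of the needs_bulk_deid gating this patch introduces. PatientTask, NlpResultsTask, fake_scrub, and prepare_input are hypothetical stand-ins invented for illustration; the real pieces are EtlTask in cumulus_etl/etl/tasks/base.py, the task subclasses in the study modules, and deid.Scrubber.scrub_bulk_data in the etl_main hunk.

    # Sketch only: mirrors the any(t.needs_bulk_deid ...) check from cli.py
    # using stand-in names, not the real cumulus_etl classes.

    class EtlTask:
        # Default: run the bulk MS de-id tool over this task's inputs.
        needs_bulk_deid = True

    class PatientTask(EtlTask):
        pass  # plain FHIR tasks keep the default of True

    class NlpResultsTask(EtlTask):
        needs_bulk_deid = False  # NLP tasks de-identify resources as they process them

    def fake_scrub(input_dir: str) -> str:
        # Hypothetical stand-in for deid.Scrubber.scrub_bulk_data().
        return input_dir + "-deidentified"

    def prepare_input(selected_tasks: list[type[EtlTask]], loaded_dir: str) -> str:
        # Run the (expensive) bulk de-id pass only if *any* selected task wants it.
        if any(t.needs_bulk_deid for t in selected_tasks):
            return fake_scrub(loaded_dir)
        print("Skipping bulk de-identification.")
        return loaded_dir

    # An NLP-only run skips the scrub; mixing in a plain FHIR task triggers it,
    # matching the three ddt cases in the new test.
    assert prepare_input([NlpResultsTask], "/tmp/in") == "/tmp/in"
    assert prepare_input([NlpResultsTask, PatientTask], "/tmp/in") == "/tmp/in-deidentified"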