From 9ba91a0b721f4e289fe79bb17c6672fcd9ae34b3 Mon Sep 17 00:00:00 2001
From: Michael Terry
Date: Wed, 4 Sep 2024 10:11:41 -0400
Subject: [PATCH] Parse deleted/*.ndjson files from a bulk export and delete records

- When loading using the default ndjson loader, we look for a deleted/ folder and read any Bundle files there for deleted IDs
- And then pass that along to tasks and matching formatters
- Formatters now have a delete_records(ids) call
- If the output format is deltalake, the IDs will be deleted
---
 cumulus_etl/etl/cli.py                     |   1 +
 cumulus_etl/etl/config.py                  |   2 +
 cumulus_etl/etl/tasks/base.py              |  38 +++++++-
 cumulus_etl/formats/base.py                |   8 ++
 cumulus_etl/formats/batched_files.py       |   9 ++
 cumulus_etl/formats/deltalake.py           |  40 ++++++--
 cumulus_etl/loaders/base.py                |   4 +
 cumulus_etl/loaders/fhir/ndjson_loader.py  |  27 ++++++
 tests/etl/test_etl_cli.py                  |  18 ++++
 tests/etl/test_tasks.py                    |  41 ++++++++
 tests/formats/test_deltalake.py            | 107 +++++++++++++++++++++
 tests/loaders/ndjson/test_ndjson_loader.py |  44 +++++++++
 12 files changed, 326 insertions(+), 13 deletions(-)

diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py
index 67d47f76..1d2504b9 100644
--- a/cumulus_etl/etl/cli.py
+++ b/cumulus_etl/etl/cli.py
@@ -296,6 +296,7 @@ async def etl_main(args: argparse.Namespace) -> None:
         tasks=[t.name for t in selected_tasks],
         export_group_name=export_group_name,
         export_datetime=export_datetime,
+        deleted_ids=loader_results.deleted_ids,
     )
 
     common.write_json(config.path_config(), config.as_json(), indent=4)
diff --git a/cumulus_etl/etl/config.py b/cumulus_etl/etl/config.py
index 9195ae2d..24918767 100644
--- a/cumulus_etl/etl/config.py
+++ b/cumulus_etl/etl/config.py
@@ -33,6 +33,7 @@ def __init__(
         tasks: list[str] | None = None,
         export_group_name: str | None = None,
         export_datetime: datetime.datetime | None = None,
+        deleted_ids: dict[str, set[str]] | None = None,
     ):
         self._dir_input_orig = dir_input_orig
         self.dir_input = dir_input_deid
@@ -50,6 +51,7 @@ def __init__(
         self.tasks = tasks or []
         self.export_group_name = export_group_name
         self.export_datetime = export_datetime
+        self.deleted_ids = deleted_ids or {}
 
         # initialize format class
         self._output_root = store.Root(self._dir_output, create=True)
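For context, the deleted_ids value threaded from the loader results into JobConfig above is a plain mapping of resource type to a set of resource IDs. A minimal sketch of the shape, using hypothetical values:

    # Hypothetical example of the mapping passed as JobConfig(deleted_ids=...)
    deleted_ids: dict[str, set[str]] = {
        "Patient": {"pat1"},
        "Condition": {"con1", "con2"},
    }
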
diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py
index f4b39aeb..00b2d403 100644
--- a/cumulus_etl/etl/tasks/base.py
+++ b/cumulus_etl/etl/tasks/base.py
@@ -143,11 +143,17 @@ async def run(self) -> list[config.JobSummary]:
 
         with self._indeterminate_progress(progress, "Finalizing"):
             # Ensure that we touch every output table (to create them and/or to confirm schema).
-            # Consider case of Medication for an EHR that only has inline Medications inside MedicationRequest.
-            # The Medication table wouldn't get created otherwise. Plus this is a good place to push any schema
-            # changes. (The reason it's nice if the table & schema exist is so that downstream SQL can be dumber.)
+            # Consider case of Medication for an EHR that only has inline Medications inside
+            # MedicationRequest.
+            # The Medication table wouldn't get created otherwise. Plus this is a good place to
+            # push any schema changes.
+            # (The reason it's nice if the table & schema exist is so that downstream SQL can
+            # be dumber.)
             self._touch_remaining_tables()
 
+            # If the input data indicates we should delete some IDs, do that here.
+            self._delete_requested_ids()
+
             # Mark this group & resource combo as complete
             self._update_completion_table()
 
@@ -228,6 +234,32 @@ def _touch_remaining_tables(self):
                 # just write an empty dataframe (should be fast)
                 self._write_one_table_batch([], table_index, 0)
 
+    def _delete_requested_ids(self):
+        """
+        Deletes IDs that have been marked for deletion.
+
+        Formatters are expected to already exist when this is called.
+
+        This usually happens via the `deleted` array from a bulk export,
+        which clients usually drop in a deleted/ folder in the download directory.
+        In our case, though, that's abstracted away into a JobConfig.deleted_ids dictionary.
+        """
+        for index, output in enumerate(self.outputs):
+            resource = output.get_resource_type(self)
+            if not resource or resource.lower() != output.get_name(self):
+                # Only delete from the main table for the resource
+                continue
+
+            deleted_ids = self.task_config.deleted_ids.get(resource, set())
+            if not deleted_ids:
+                continue
+
+            deleted_ids = {
+                self.scrubber.codebook.fake_id(resource, x, caching_allowed=False)
+                for x in deleted_ids
+            }
+            self.formatters[index].delete_records(deleted_ids)
+
     def _update_completion_table(self) -> None:
         # TODO: what about empty sets - do we assume the export gave 0 results or skip it?
         # Is there a difference we could notice? (like empty input file vs no file at all)
diff --git a/cumulus_etl/formats/base.py b/cumulus_etl/formats/base.py
index 8f56dd46..7c56f8da 100644
--- a/cumulus_etl/formats/base.py
+++ b/cumulus_etl/formats/base.py
@@ -74,6 +74,14 @@ def _write_one_batch(self, batch: Batch) -> None:
         :param batch: the batch of data
         """
 
+    @abc.abstractmethod
+    def delete_records(self, ids: set[str]) -> None:
+        """
+        Deletes all mentioned IDs from the table.
+
+        :param ids: all IDs to remove
+        """
+
     def finalize(self) -> None:
         """
         Performs any necessary cleanup after all batches have been written.
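Because delete_records() is declared as an abstract method above, every concrete formatter now has to provide an implementation. A minimal sketch of what a custom formatter might look like after this change - the Format and Batch names and the import path are assumptions based on the file above, not something this patch defines:

    from cumulus_etl.formats.base import Batch, Format  # assumed import path

    class DiscardFormat(Format):
        # Hypothetical formatter that ignores writes and deletes.

        def _write_one_batch(self, batch: Batch) -> None:
            pass  # nothing is stored

        def delete_records(self, ids: set[str]) -> None:
            pass  # nothing stored, so nothing to delete
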
diff --git a/cumulus_etl/formats/batched_files.py b/cumulus_etl/formats/batched_files.py
index 9369a195..25e45882 100644
--- a/cumulus_etl/formats/batched_files.py
+++ b/cumulus_etl/formats/batched_files.py
@@ -83,3 +83,12 @@ def _write_one_batch(self, batch: Batch) -> None:
         full_path = self.dbroot.joinpath(f"{self.dbname}.{self._index:03}.{self.suffix}")
         self.write_format(batch, full_path)
         self._index += 1
+
+    def delete_records(self, ids: set[str]) -> None:
+        """
+        Deletes the given IDs.
+
+        This is a no-op for batched file outputs, since:
+        - we guarantee the output folder is empty at the start
+        - the spec says deleted IDs won't overlap with output IDs
+        """
diff --git a/cumulus_etl/formats/deltalake.py b/cumulus_etl/formats/deltalake.py
index 1e8fc60d..06aa2864 100644
--- a/cumulus_etl/formats/deltalake.py
+++ b/cumulus_etl/formats/deltalake.py
@@ -91,8 +91,6 @@ def initialize_class(cls, root: store.Root) -> None:
     def _write_one_batch(self, batch: Batch) -> None:
         """Writes the whole dataframe to a delta lake"""
         with self.batch_to_spark(batch) as updates:
-            if updates is None:
-                return
             delta_table = self.update_delta_table(updates, groups=batch.groups)
 
         delta_table.generate("symlink_format_manifest")
@@ -131,16 +129,25 @@ def update_delta_table(
 
         return table
 
-    def finalize(self) -> None:
-        """Performs any necessary cleanup after all batches have been written"""
-        full_path = self._table_path(self.dbname)
+    def delete_records(self, ids: set[str]) -> None:
+        """Deletes the given IDs."""
+        if not ids:
+            return
+
+        table = self._load_table()
+        if not table:
+            return
 
         try:
-            table = delta.DeltaTable.forPath(self.spark, full_path)
-        except AnalysisException:
-            return  # if the table doesn't exist because we didn't write anything, that's fine - just bail
+            id_list = "', '".join(ids)
+            table.delete(f"id in ('{id_list}')")
         except Exception:
-            logging.exception("Could not finalize Delta Lake table %s", self.dbname)
+            logging.exception("Could not delete IDs from Delta Lake table %s", self.dbname)
+
+    def finalize(self) -> None:
+        """Performs any necessary cleanup after all batches have been written"""
+        table = self._load_table()
+        if not table:
             return
 
         try:
@@ -154,6 +161,19 @@ def _table_path(self, dbname: str) -> str:
         # hadoop uses the s3a: scheme instead of s3:
         return self.root.joinpath(dbname).replace("s3://", "s3a://")
 
+    def _load_table(self) -> delta.DeltaTable | None:
+        full_path = self._table_path(self.dbname)
+
+        try:
+            return delta.DeltaTable.forPath(self.spark, full_path)
+        except AnalysisException:
+            # The table likely doesn't exist.
+            # That can be normal if we didn't write anything yet; that's fine - just bail.
+            return None
+        except Exception:
+            logging.exception("Could not load Delta Lake table %s", self.dbname)
+            return None
+
     @staticmethod
     def _get_update_condition(schema: pyspark.sql.types.StructType) -> str | None:
         """
@@ -214,7 +234,7 @@ def _configure_fs(root: store.Root, spark: pyspark.sql.SparkSession):
             spark.conf.set("fs.s3a.endpoint.region", region_name)
 
     @contextlib.contextmanager
-    def batch_to_spark(self, batch: Batch) -> pyspark.sql.DataFrame | None:
+    def batch_to_spark(self, batch: Batch) -> pyspark.sql.DataFrame:
         """Transforms a batch to a spark DF"""
         # This is the quick and dirty way - write batch to parquet with pyarrow and read it back.
         # But a more direct way would be to convert the pyarrow schema to a pyspark schema and just
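To make the Delta Lake delete path concrete: delete_records() turns the set of (already de-identified) IDs into a SQL-style predicate for DeltaTable.delete(). A small standalone sketch of that string building, mirroring the code above:

    ids = {"pat1", "con1"}  # hypothetical anonymized IDs
    id_list = "', '".join(ids)
    predicate = f"id in ('{id_list}')"  # e.g. "id in ('pat1', 'con1')" (set order varies)
    # delta_table.delete(predicate) would then drop any rows whose id matches
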
diff --git a/cumulus_etl/loaders/base.py b/cumulus_etl/loaders/base.py
index 9db4a304..f773df0e 100644
--- a/cumulus_etl/loaders/base.py
+++ b/cumulus_etl/loaders/base.py
@@ -23,6 +23,10 @@ def path(self) -> str:
     group_name: str | None = None
     export_datetime: datetime.datetime | None = None
 
+    # Resource IDs that should be deleted from the output tables.
+    # This is a map of resource -> set of IDs like {"Patient": {"A", "B"}}
+    deleted_ids: dict[str, set[str]] = dataclasses.field(default_factory=dict)
+
 
 class Loader(abc.ABC):
     """
diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py
index b35d7163..b4ef2bb6 100644
--- a/cumulus_etl/loaders/fhir/ndjson_loader.py
+++ b/cumulus_etl/loaders/fhir/ndjson_loader.py
@@ -62,6 +62,8 @@ async def load_all(self, resources: list[str]) -> base.LoaderResults:
             # For now, just ignore any errors.
             pass
 
+        results.deleted_ids = self.read_deleted_ids(input_root)
+
         # Copy the resources we need from the remote directory (like S3 buckets) to a local one.
         #
         # We do this even if the files are local, because the next step in our pipeline is the MS deid tool,
@@ -112,3 +114,28 @@ async def load_from_bulk_export(
             group_name=bulk_exporter.group_name,
             export_datetime=bulk_exporter.export_datetime,
         )
+
+    def read_deleted_ids(self, root: store.Root) -> dict[str, set[str]]:
+        """
+        Reads any deleted IDs that a bulk export gave us.
+
+        See https://hl7.org/fhir/uv/bulkdata/export.html for details.
+        """
+        deleted_ids = {}
+
+        subdir = store.Root(root.joinpath("deleted"))
+        bundles = common.read_resource_ndjson(subdir, "Bundle")
+        for bundle in bundles:
+            if bundle.get("type") != "transaction":
+                continue
+            for entry in bundle.get("entry", []):
+                request = entry.get("request", {})
+                if request.get("method") != "DELETE":
+                    continue
+                url = request.get("url")  # should be a relative URL like "Patient/123"
+                if not url or url.count("/") != 1:
+                    continue
+                resource, res_id = url.split("/")
+                deleted_ids.setdefault(resource, set()).add(res_id)
+
+        return deleted_ids
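For reference, read_deleted_ids() expects each line of a deleted/*.ndjson file to be a Bundle of type "transaction" whose entries carry DELETE requests, per the bulk export spec linked above. A sketch of one parsed line and the mapping it would fold into, using hypothetical IDs:

    # One parsed line of deleted/deletes.ndjson might look like:
    bundle = {
        "resourceType": "Bundle",
        "type": "transaction",
        "entry": [
            {"request": {"method": "DELETE", "url": "Patient/pat1"}},
            {"request": {"method": "DELETE", "url": "Condition/con1"}},
        ],
    }
    # read_deleted_ids() would then return {"Patient": {"pat1"}, "Condition": {"con1"}}
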
diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py
index 30e6394e..7c67db2f 100644
--- a/tests/etl/test_etl_cli.py
+++ b/tests/etl/test_etl_cli.py
@@ -287,6 +287,24 @@ async def test_completion_args(self, etl_args, loader_vals, expected_vals):
         self.assertEqual(expected_vals[0], config.export_group_name)
         self.assertEqual(expected_vals[1], config.export_datetime)
 
+    async def test_deleted_ids_passed_down(self):
+        """Verify that we pass along any deleted IDs to the JobConfig."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            results = loaders.LoaderResults(
+                directory=common.RealDirectory(tmpdir), deleted_ids={"Observation": {"obs1"}}
+            )
+
+            with (
+                self.assertRaises(SystemExit) as cm,
+                mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job,
+                mock.patch.object(loaders.FhirNdjsonLoader, "load_all", return_value=results),
+            ):
+                await self.run_etl(tasks=["observation"])
+
+        self.assertEqual(mock_etl_job.call_count, 1)
+        config = mock_etl_job.call_args[0][0]
+        self.assertEqual({"Observation": {"obs1"}}, config.deleted_ids)
+
 
 class TestEtlJobConfig(BaseEtlSimple):
     """Test case for the job config logging data"""
diff --git a/tests/etl/test_tasks.py b/tests/etl/test_tasks.py
index 7bc406be..b3ac4085 100644
--- a/tests/etl/test_tasks.py
+++ b/tests/etl/test_tasks.py
@@ -7,6 +7,7 @@
 import ddt
 
 from cumulus_etl import common, errors
+from cumulus_etl.etl import tasks
 from cumulus_etl.etl.tasks import basic_tasks, task_factory
 
 from tests.etl import TaskTestCase
@@ -133,6 +134,46 @@ async def test_batch_is_given_schema(self):
         self.assertIn("address", schema.names)
         self.assertIn("id", schema.names)
 
+    async def test_get_schema(self):
+        """Verify that Task.get_schema() works for resources and non-resources"""
+        schema = tasks.EtlTask.get_schema("Patient", [])
+        self.assertIn("gender", schema.names)
+        schema = tasks.EtlTask.get_schema(None, [])
+        self.assertIsNone(schema)
+
+    async def test_prepare_can_skip_task(self):
+        """Verify that if prepare_task returns false, we skip the task"""
+        self.make_json("Patient", "A")
+        with mock.patch(
+            "cumulus_etl.etl.tasks.basic_tasks.PatientTask.prepare_task", return_value=False
+        ):
+            summaries = await basic_tasks.PatientTask(self.job_config, self.scrubber).run()
+        self.assertEqual(len(summaries), 1)
+        self.assertEqual(summaries[0].attempt, 0)
+        self.assertIsNone(self.format)
+
+    async def test_deleted_ids_no_op(self):
+        """Verify that we don't try to delete IDs if none are given"""
+        # Just a simple test to confirm we don't even ask the formatter to consider
+        # deleting any IDs if we weren't given any.
+        await basic_tasks.PatientTask(self.job_config, self.scrubber).run()
+        self.assertEqual(self.format.delete_records.call_count, 0)
+
+    async def test_deleted_ids(self):
+        """Verify that we send deleted IDs down to the formatter"""
+        self.job_config.deleted_ids = {"Patient": {"p1", "p2"}}
+        await basic_tasks.PatientTask(self.job_config, self.scrubber).run()
+
+        self.assertEqual(self.format.delete_records.call_count, 1)
+        ids = self.format.delete_records.call_args[0][0]
+        self.assertEqual(
+            ids,
+            {
+                self.codebook.db.resource_hash("p1"),
+                self.codebook.db.resource_hash("p2"),
+            },
+        )
+
 
 @ddt.ddt
 class TestTaskCompletion(TaskTestCase):
diff --git a/tests/formats/test_deltalake.py b/tests/formats/test_deltalake.py
index e1cfb89b..4bfa3aad 100644
--- a/tests/formats/test_deltalake.py
+++ b/tests/formats/test_deltalake.py
@@ -5,6 +5,7 @@
 import os
 import shutil
 import tempfile
+from unittest import mock
 
 import ddt
 import pyarrow
@@ -395,3 +396,109 @@ def test_update_existing(self):
         self.store(self.df(a=1, b=2))
         self.store(self.df(a=999, c=3), update_existing=False)
         self.assert_lake_equal(self.df(a=1, b=2, c=3))
+
+    def test_s3_options(self):
+        """Verify that we read in S3 options and set spark config appropriately"""
+        # Save global/class-wide spark object, to be restored. Then clear it out.
+        old_spark = DeltaLakeFormat.spark
+
+        def restore_spark():
+            DeltaLakeFormat.spark = old_spark
+
+        self.addCleanup(restore_spark)
+        DeltaLakeFormat.spark = None
+
+        # Now re-initialize the class, mocking out all the slow spark stuff, and using S3.
+        fs_options = {
+            "s3_kms_key": "test-key",
+            "s3_region": "us-west-1",
+        }
+        with (
+            mock.patch("cumulus_etl.store._user_fs_options", new=fs_options),
+            mock.patch("delta.configure_spark_with_delta_pip"),
+            mock.patch("pyspark.sql"),
+        ):
+            DeltaLakeFormat.initialize_class(store.Root("s3://test/"))
+
+        self.assertEqual(
+            sorted(DeltaLakeFormat.spark.conf.set.call_args_list, key=lambda x: x[0][0]),
+            [
+                mock.call(
+                    "fs.s3a.aws.credentials.provider",
+                    "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
+                ),
+                mock.call("fs.s3a.endpoint.region", "us-west-1"),
+                mock.call("fs.s3a.server-side-encryption-algorithm", "SSE-KMS"),
+                mock.call("fs.s3a.server-side-encryption.key", "test-key"),
+                mock.call("fs.s3a.sse.enabled", "true"),
+            ],
+        )
+
+    def test_finalize_happy_path(self):
+        """Verify that we clean up the delta lake when finalizing."""
+        # Limit our fake table to just these attributes, to notice any new usage in future
+        mock_table = mock.MagicMock(spec=["generate", "optimize", "vacuum"])
+        self.patch("delta.DeltaTable.forPath", return_value=mock_table)
+
+        DeltaLakeFormat(self.root, "patient").finalize()
+        self.assertEqual(mock_table.optimize.call_args_list, [mock.call()])
+        self.assertEqual(
+            mock_table.optimize.return_value.executeCompaction.call_args_list, [mock.call()]
+        )
+        self.assertEqual(mock_table.generate.call_args_list, [mock.call("symlink_format_manifest")])
+        self.assertEqual(mock_table.vacuum.call_args_list, [mock.call()])
+
+    def test_finalize_cannot_load_table(self):
+        """Verify that we gracefully handle failing to read an existing table when finalizing."""
+        # No table
+        deltalake = DeltaLakeFormat(self.root, "patient")
+        with self.assertNoLogs():
+            deltalake.finalize()
+        self.assertFalse(os.path.exists(self.output_dir))
+
+        # Error loading the table
+        with self.assertLogs(level="ERROR") as logs:
+            with mock.patch("delta.DeltaTable.forPath", side_effect=ValueError):
+                deltalake.finalize()
+        self.assertEqual(len(logs.output), 1)
+        self.assertTrue(
+            logs.output[0].startswith("ERROR:root:Could not load Delta Lake table patient\n")
+        )
+
+    def test_finalize_error(self):
+        """Verify that we gracefully handle an error while finalizing."""
+        self.store(self.df(a=1))  # create a simple table to load
+        with self.assertLogs(level="ERROR") as logs:
+            with mock.patch("delta.DeltaTable.optimize", side_effect=ValueError):
+                DeltaLakeFormat(self.root, "patient").finalize()
+        self.assertEqual(len(logs.output), 1)
+        self.assertTrue(
+            logs.output[0].startswith("ERROR:root:Could not finalize Delta Lake table patient\n")
+        )
+
+    def test_delete_records_happy_path(self):
+        """Verify that `delete_records` works in a basic way."""
+        self.store(self.df(a=1, b=2, c=3, d=4))
+
+        deltalake = DeltaLakeFormat(self.root, "patient")
+        deltalake.delete_records({"a", "c"})
+        deltalake.delete_records({"d"})
+        deltalake.delete_records(set())
+
+        self.assert_lake_equal(self.df(b=2))
+
+    def test_delete_records_error(self):
+        """Verify that `delete_records` handles errors gracefully."""
+        mock_table = mock.MagicMock(spec=["delete"])
+        mock_table.delete.side_effect = ValueError
+        self.patch("delta.DeltaTable.forPath", return_value=mock_table)
+
+        with self.assertLogs(level="ERROR") as logs:
+            DeltaLakeFormat(self.root, "patient").delete_records("a")
+
+        self.assertEqual(len(logs.output), 1)
+        self.assertTrue(
+            logs.output[0].startswith(
+                "ERROR:root:Could not delete IDs from Delta Lake table patient\n"
+            )
+        )
diff --git a/tests/loaders/ndjson/test_ndjson_loader.py b/tests/loaders/ndjson/test_ndjson_loader.py
index 80dc352a..ea88e143 100644
--- a/tests/loaders/ndjson/test_ndjson_loader.py
+++ b/tests/loaders/ndjson/test_ndjson_loader.py
@@ -350,3 +350,47 @@ async def test_export_to_folder_not_local(self):
         with self.assertRaises(SystemExit) as cm:
             await loader.load_all([])
         self.assertEqual(cm.exception.code, errors.BULK_EXPORT_FOLDER_NOT_LOCAL)
+
+    async def test_reads_deleted_ids(self):
+        """Verify we read in the deleted/ folder"""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            os.mkdir(f"{tmpdir}/deleted")
+            common.write_json(
+                f"{tmpdir}/deleted/deletes.ndjson",
+                {
+                    "resourceType": "Bundle",
+                    "type": "transaction",
+                    "entry": [
+                        {"request": {"method": "GET", "url": "Patient/bad-method"}},
+                        {"request": {"method": "DELETE", "url": "Patient/pat1"}},
+                        {"request": {"method": "DELETE", "url": "Patient/too/many/slashes"}},
+                        {"request": {"method": "DELETE", "url": "Condition/con1"}},
+                        {"request": {"method": "DELETE", "url": "Condition/con2"}},
+                    ],
+                },
+            )
+            # This next bundle will be ignored because of the wrong "type"
+            common.write_json(
+                f"{tmpdir}/deleted/messages.ndjson",
+                {
+                    "resourceType": "Bundle",
+                    "type": "message",
+                    "entry": [
+                        {
+                            "request": {"method": "DELETE", "url": "Patient/wrong-message-type"},
+                        }
+                    ],
+                },
+            )
+            # This next file will be ignored because of the wrong "resourceType"
+            common.write_json(
+                f"{tmpdir}/deleted/conditions-for-some-reason.ndjson",
+                {
+                    "resourceType": "Condition",
+                    "recordedDate": "2024-09-04",
+                },
+            )
+            loader = loaders.FhirNdjsonLoader(store.Root(tmpdir))
+            results = await loader.load_all(["Patient"])
+
+        self.assertEqual(results.deleted_ids, {"Patient": {"pat1"}, "Condition": {"con1", "con2"}})