Parse deleted/*.ndjson files from a bulk export and delete records
- When loading using the default ndjson loader, we look for a deleted/
  folder and read any Bundle files there for deleted IDs
- And then pass that along to tasks and matching formatters
- Formatters now have a delete_records(ids) call
- If the output format is deltalake, the IDs will be deleted
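
For context, the bulk export spec delivers deletion requests as transaction Bundles whose entries carry DELETE requests with relative URLs like "Patient/123", which is exactly the shape the new read_deleted_ids() parser below looks for. A minimal sketch of one such Bundle line and the map it turns into (the file name and IDs are hypothetical):

# Hypothetical contents of one line in deleted/deleted-output.ndjson:
bundle = {
    "resourceType": "Bundle",
    "type": "transaction",
    "entry": [
        {"request": {"method": "DELETE", "url": "Patient/p1"}},
        {"request": {"method": "DELETE", "url": "Observation/obs1"}},
    ],
}

# After FhirNdjsonLoader.read_deleted_ids() walks the Bundles, LoaderResults.deleted_ids
# carries a resource -> set-of-IDs map for the JobConfig and tasks to consume:
# {"Patient": {"p1"}, "Observation": {"obs1"}}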
mikix committed Sep 4, 2024
1 parent 3180644 commit 9ba91a0
Showing 12 changed files with 326 additions and 13 deletions.
1 change: 1 addition & 0 deletions cumulus_etl/etl/cli.py
@@ -296,6 +296,7 @@ async def etl_main(args: argparse.Namespace) -> None:
tasks=[t.name for t in selected_tasks],
export_group_name=export_group_name,
export_datetime=export_datetime,
deleted_ids=loader_results.deleted_ids,
)
common.write_json(config.path_config(), config.as_json(), indent=4)

2 changes: 2 additions & 0 deletions cumulus_etl/etl/config.py
@@ -33,6 +33,7 @@ def __init__(
tasks: list[str] | None = None,
export_group_name: str | None = None,
export_datetime: datetime.datetime | None = None,
deleted_ids: dict[str, set[str]] | None = None,
):
self._dir_input_orig = dir_input_orig
self.dir_input = dir_input_deid
@@ -50,6 +51,7 @@ def __init__(
self.tasks = tasks or []
self.export_group_name = export_group_name
self.export_datetime = export_datetime
self.deleted_ids = deleted_ids or {}

# initialize format class
self._output_root = store.Root(self._dir_output, create=True)
38 changes: 35 additions & 3 deletions cumulus_etl/etl/tasks/base.py
@@ -143,11 +143,17 @@ async def run(self) -> list[config.JobSummary]:

with self._indeterminate_progress(progress, "Finalizing"):
# Ensure that we touch every output table (to create them and/or to confirm schema).
# Consider case of Medication for an EHR that only has inline Medications inside MedicationRequest.
# The Medication table wouldn't get created otherwise. Plus this is a good place to push any schema
# changes. (The reason it's nice if the table & schema exist is so that downstream SQL can be dumber.)
# Consider case of Medication for an EHR that only has inline Medications inside
# MedicationRequest.
# The Medication table wouldn't get created otherwise. Plus this is a good place to
# push any schema changes.
# (The reason it's nice if the table & schema exist is so that downstream SQL can
# be dumber.)
self._touch_remaining_tables()

# If the input data indicates we should delete some IDs, do that here.
self._delete_requested_ids()

# Mark this group & resource combo as complete
self._update_completion_table()

@@ -228,6 +234,32 @@ def _touch_remaining_tables(self):
# just write an empty dataframe (should be fast)
self._write_one_table_batch([], table_index, 0)

def _delete_requested_ids(self):
"""
Deletes IDs that have been marked for deletion.
Formatters are expected to already exist when this is called.
Deletion requests usually arrive via the `deleted` array from a bulk export,
which clients typically drop into a deleted/ folder in the download directory.
In our case, that's abstracted away into a JobConfig.deleted_ids dictionary.
"""
for index, output in enumerate(self.outputs):
resource = output.get_resource_type(self)
if not resource or resource.lower() != output.get_name(self):
# Only delete from the main table for the resource
continue

deleted_ids = self.task_config.deleted_ids.get(resource, set())
if not deleted_ids:
continue

deleted_ids = {
self.scrubber.codebook.fake_id(resource, x, caching_allowed=False)
for x in deleted_ids
}
self.formatters[index].delete_records(deleted_ids)

def _update_completion_table(self) -> None:
# TODO: what about empty sets - do we assume the export gave 0 results or skip it?
# Is there a difference we could notice? (like empty input file vs no file at all)
8 changes: 8 additions & 0 deletions cumulus_etl/formats/base.py
@@ -74,6 +74,14 @@ def _write_one_batch(self, batch: Batch) -> None:
:param batch: the batch of data
"""

@abc.abstractmethod
def delete_records(self, ids: set[str]) -> None:
"""
Deletes all mentioned IDs from the table.
:param ids: all IDs to remove
"""

def finalize(self) -> None:
"""
Performs any necessary cleanup after all batches have been written.
9 changes: 9 additions & 0 deletions cumulus_etl/formats/batched_files.py
@@ -83,3 +83,12 @@ def _write_one_batch(self, batch: Batch) -> None:
full_path = self.dbroot.joinpath(f"{self.dbname}.{self._index:03}.{self.suffix}")
self.write_format(batch, full_path)
self._index += 1

def delete_records(self, ids: set[str]) -> None:
"""
Deletes the given IDs.
This is a no-op for batched file outputs, because:
- we guarantee the output folder is empty at the start
- the spec says deleted IDs won't overlap with output IDs
"""
40 changes: 30 additions & 10 deletions cumulus_etl/formats/deltalake.py
@@ -91,8 +91,6 @@ def initialize_class(cls, root: store.Root) -> None:
def _write_one_batch(self, batch: Batch) -> None:
"""Writes the whole dataframe to a delta lake"""
with self.batch_to_spark(batch) as updates:
if updates is None:
return
delta_table = self.update_delta_table(updates, groups=batch.groups)

delta_table.generate("symlink_format_manifest")
@@ -131,16 +129,25 @@ def update_delta_table(

return table

def finalize(self) -> None:
"""Performs any necessary cleanup after all batches have been written"""
full_path = self._table_path(self.dbname)
def delete_records(self, ids: set[str]) -> None:
"""Deletes the given IDs."""
if not ids:
return

table = self._load_table()
if not table:
return

try:
table = delta.DeltaTable.forPath(self.spark, full_path)
except AnalysisException:
return # if the table doesn't exist because we didn't write anything, that's fine - just bail
id_list = "', '".join(ids)
table.delete(f"id in ('{id_list}')")
except Exception:
logging.exception("Could not finalize Delta Lake table %s", self.dbname)
logging.exception("Could not delete IDs from Delta Lake table %s", self.dbname)

def finalize(self) -> None:
"""Performs any necessary cleanup after all batches have been written"""
table = self._load_table()
if not table:
return

try:
Expand All @@ -154,6 +161,19 @@ def _table_path(self, dbname: str) -> str:
# hadoop uses the s3a: scheme instead of s3:
return self.root.joinpath(dbname).replace("s3://", "s3a://")

def _load_table(self) -> delta.DeltaTable | None:
full_path = self._table_path(self.dbname)

try:
return delta.DeltaTable.forPath(self.spark, full_path)
except AnalysisException:
# The table likely doesn't exist.
# Which can be normal if we didn't write anything yet, that's fine - just bail.
return None
except Exception:
logging.exception("Could not load Delta Lake table %s", self.dbname)
return None

@staticmethod
def _get_update_condition(schema: pyspark.sql.types.StructType) -> str | None:
"""
@@ -214,7 +234,7 @@ def _configure_fs(root: store.Root, spark: pyspark.sql.SparkSession):
spark.conf.set("fs.s3a.endpoint.region", region_name)

@contextlib.contextmanager
def batch_to_spark(self, batch: Batch) -> pyspark.sql.DataFrame | None:
def batch_to_spark(self, batch: Batch) -> pyspark.sql.DataFrame:
"""Transforms a batch to a spark DF"""
# This is the quick and dirty way - write batch to parquet with pyarrow and read it back.
# But a more direct way would be to convert the pyarrow schema to a pyspark schema and just
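One note on the Delta Lake path above: DeltaTable.delete() accepts a SQL predicate string, so delete_records() effectively runs a DELETE ... WHERE id IN (...) against the table. A minimal standalone sketch of the same call, assuming a Spark session already configured for Delta Lake; the table path and IDs below are hypothetical, and in the ETL the IDs have already been pseudonymized by the codebook before they reach the formatter:

import delta

def delete_ids(spark, table_path: str, ids: set[str]) -> None:
    # Load the existing table; the formatter wraps this in error handling and
    # quietly bails if the table was never created.
    table = delta.DeltaTable.forPath(spark, table_path)
    # Build the same "id in ('...')" predicate that delete_records() uses.
    id_list = "', '".join(ids)
    table.delete(f"id in ('{id_list}')")

# Example call (path and IDs are made up):
# delete_ids(spark, "s3a://my-bucket/patient", {"fake-id-1", "fake-id-2"})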
4 changes: 4 additions & 0 deletions cumulus_etl/loaders/base.py
@@ -23,6 +23,10 @@ def path(self) -> str:
group_name: str | None = None
export_datetime: datetime.datetime | None = None

# Resource IDs that should be deleted from the output tables,
# as a map of resource type -> set of IDs, like {"Patient": {"A", "B"}}
deleted_ids: dict[str, set[str]] = dataclasses.field(default_factory=dict)


class Loader(abc.ABC):
"""
27 changes: 27 additions & 0 deletions cumulus_etl/loaders/fhir/ndjson_loader.py
@@ -62,6 +62,8 @@ async def load_all(self, resources: list[str]) -> base.LoaderResults:
# For now, just ignore any errors.
pass

results.deleted_ids = self.read_deleted_ids(input_root)

# Copy the resources we need from the remote directory (like S3 buckets) to a local one.
#
# We do this even if the files are local, because the next step in our pipeline is the MS deid tool,
@@ -112,3 +114,28 @@ async def load_from_bulk_export(
group_name=bulk_exporter.group_name,
export_datetime=bulk_exporter.export_datetime,
)

def read_deleted_ids(self, root: store.Root) -> dict[str, set[str]]:
"""
Reads any deleted IDs that a bulk export gave us.
See https://hl7.org/fhir/uv/bulkdata/export.html for details.
"""
deleted_ids = {}

subdir = store.Root(root.joinpath("deleted"))
bundles = common.read_resource_ndjson(subdir, "Bundle")
for bundle in bundles:
if bundle.get("type") != "transaction":
continue
for entry in bundle.get("entry", []):
request = entry.get("request", {})
if request.get("method") != "DELETE":
continue
url = request.get("url") # should be relative URL like "Patient/123"
if not url or url.count("/") != 1:
continue
resource, res_id = url.split("/")
deleted_ids.setdefault(resource, set()).add(res_id)

return deleted_ids
18 changes: 18 additions & 0 deletions tests/etl/test_etl_cli.py
@@ -287,6 +287,24 @@ async def test_completion_args(self, etl_args, loader_vals, expected_vals):
self.assertEqual(expected_vals[0], config.export_group_name)
self.assertEqual(expected_vals[1], config.export_datetime)

async def test_deleted_ids_passed_down(self):
"""Verify that we parse pass along any deleted ids to the JobConfig."""
with tempfile.TemporaryDirectory() as tmpdir:
results = loaders.LoaderResults(
directory=common.RealDirectory(tmpdir), deleted_ids={"Observation": {"obs1"}}
)

with (
self.assertRaises(SystemExit) as cm,

mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job,
mock.patch.object(loaders.FhirNdjsonLoader, "load_all", return_value=results),
):
await self.run_etl(tasks=["observation"])

self.assertEqual(mock_etl_job.call_count, 1)
config = mock_etl_job.call_args[0][0]
self.assertEqual({"Observation": {"obs1"}}, config.deleted_ids)


class TestEtlJobConfig(BaseEtlSimple):
"""Test case for the job config logging data"""
41 changes: 41 additions & 0 deletions tests/etl/test_tasks.py
@@ -7,6 +7,7 @@
import ddt

from cumulus_etl import common, errors
from cumulus_etl.etl import tasks
from cumulus_etl.etl.tasks import basic_tasks, task_factory
from tests.etl import TaskTestCase

@@ -133,6 +134,46 @@ async def test_batch_is_given_schema(self):
self.assertIn("address", schema.names)
self.assertIn("id", schema.names)

async def test_get_schema(self):
"""Verify that Task.get_schema() works for resources and non-resources"""
schema = tasks.EtlTask.get_schema("Patient", [])
self.assertIn("gender", schema.names)
schema = tasks.EtlTask.get_schema(None, [])
self.assertIsNone(schema)

async def test_prepare_can_skip_task(self):
"""Verify that if prepare_task returns false, we skip the task"""
self.make_json("Patient", "A")
with mock.patch(
"cumulus_etl.etl.tasks.basic_tasks.PatientTask.prepare_task", return_value=False
):
summaries = await basic_tasks.PatientTask(self.job_config, self.scrubber).run()
self.assertEqual(len(summaries), 1)
self.assertEqual(summaries[0].attempt, 0)
self.assertIsNone(self.format)

async def test_deleted_ids_no_op(self):
"""Verify that we don't try to delete IDs if none are given"""
# Just a simple test to confirm we don't even ask the formatter to consider
# deleting any IDs if we weren't given any.
await basic_tasks.PatientTask(self.job_config, self.scrubber).run()
self.assertEqual(self.format.delete_records.call_count, 0)

async def test_deleted_ids(self):
"""Verify that we send deleted IDs down to the formatter"""
self.job_config.deleted_ids = {"Patient": {"p1", "p2"}}
await basic_tasks.PatientTask(self.job_config, self.scrubber).run()

self.assertEqual(self.format.delete_records.call_count, 1)
ids = self.format.delete_records.call_args[0][0]
self.assertEqual(
ids,
{
self.codebook.db.resource_hash("p1"),
self.codebook.db.resource_hash("p2"),
},
)


@ddt.ddt
class TestTaskCompletion(TaskTestCase):