Merge pull request #247 from smart-on-fhir/mikix/schema-but-moreso
fix: make sure that no inferred schemas ever hit a delta lake
mikix authored Jul 19, 2023
2 parents 75c68c9 + 5265647 commit bab85b6
Showing 30 changed files with 651 additions and 447 deletions.
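
In short, every batch written out now carries an explicit pyarrow schema (derived from the FHIR spec, or hard-coded per task) instead of letting the output layer infer column types from whichever rows happen to be present. A minimal sketch of why inference is fragile, illustrative only and not code from this repo:

```python
import pyarrow

rows = [
    {"id": "obs1", "valueQuantity": None},  # this batch happens to have no values
    {"id": "obs2", "valueQuantity": None},
]

# Inferred: with every value missing, pyarrow types the column as "null",
# which a later batch that does carry real structs will not line up with.
inferred = pyarrow.Table.from_pylist(rows)
print(inferred.schema.field("valueQuantity").type)  # null

# Explicit: declare the intended type up front so every batch agrees.
schema = pyarrow.schema([
    pyarrow.field("id", pyarrow.string()),
    pyarrow.field("valueQuantity", pyarrow.struct([pyarrow.field("value", pyarrow.float64())])),
])
explicit = pyarrow.Table.from_pylist(rows, schema=schema)
print(explicit.schema.field("valueQuantity").type)  # struct<value: double>
```
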
67 changes: 54 additions & 13 deletions cumulus_etl/common.py
@@ -33,14 +33,18 @@ def ls_resources(root, resource: str) -> list[str]:


@contextlib.contextmanager
def _atomic_open(path: str, mode: str) -> TextIO:
def _atomic_open(path: str, mode: str, **kwargs) -> TextIO:
"""A version of open() that handles atomic file access across many filesystems (like S3)"""
root = store.Root(path)

# fsspec is atomic per-transaction -- if an error occurs inside the transaction, partial writes will be discarded
with root.fs.transaction:
with root.fs.open(path, mode=mode, encoding="utf8") as file:
yield file
with contextlib.ExitStack() as stack:
if "w" in mode:
# fsspec is atomic per-transaction.
# If an error occurs inside the transaction, partial writes will be discarded.
# But we only want a transaction if we're writing - read transactions may error out
stack.enter_context(root.fs.transaction)

yield stack.enter_context(root.fs.open(path, mode=mode, encoding="utf8", **kwargs))


def read_text(path: str) -> str:
@@ -94,7 +98,8 @@ def write_json(path: str, data: Any, indent: int = None) -> None:

@contextlib.contextmanager
def read_csv(path: str) -> csv.DictReader:
with open(path, newline="", encoding="utf8") as csvfile:
# Python docs say to use newline="", to support quoted multi-line fields
with _atomic_open(path, "r", newline="") as csvfile:
yield csv.DictReader(csvfile)


@@ -115,13 +120,34 @@ def read_resource_ndjson(root, resource: str) -> Iterator[dict]:
yield from read_ndjson(filename)


def write_rows_to_ndjson(path: str, rows: list[dict], sparse: bool = False) -> None:
"""
Writes the data out, row by row, to an .ndjson file (non-atomically).
:param path: where to write the file
:param rows: data to write
:param sparse: if True, None entries are skipped
"""
with NdjsonWriter(path, allow_empty=True) as f:
for row in rows:
if sparse:
row = sparse_dict(row)
f.write(row)


class NdjsonWriter:
"""Convenience context manager to write multiple objects to a local ndjson file."""
"""
Convenience context manager to write multiple objects to a ndjson file.
Note that this is not atomic - partial writes will make it to the target file.
"""

def __init__(self, path: str, mode: str = "w"):
self._path = path
def __init__(self, path: str, mode: str = "w", allow_empty: bool = False):
self._root = store.Root(path)
self._mode = mode
self._file = None
if allow_empty:
self._ensure_file()

def __enter__(self):
return self
@@ -131,15 +157,31 @@ def __exit__(self, exc_type, exc_value, traceback):
self._file.close()
self._file = None

def write(self, obj: dict) -> None:
# lazily create the file, to avoid 0-line ndjson files
def _ensure_file(self):
if not self._file:
self._file = open(self._path, self._mode, encoding="utf8") # pylint: disable=consider-using-with
self._file = self._root.fs.open(self._root.path, self._mode, encoding="utf8")

def write(self, obj: dict) -> None:
# lazily create the file, to avoid 0-line ndjson files (unless created in __init__)
self._ensure_file()

json.dump(obj, self._file)
self._file.write("\n")


def sparse_dict(dictionary: dict) -> dict:
"""Returns a value of the input dictionary without any keys with None values."""

def iteration(item: Any) -> Any:
if isinstance(item, dict):
return {key: iteration(val) for key, val in item.items() if val is not None}
elif isinstance(item, list):
return [iteration(x) for x in item]
return item

return iteration(dictionary)


###############################################################################
#
# Helper Functions: Logging
@@ -162,7 +204,6 @@ def human_file_size(count: int) -> str:
Returns a human-readable version of a count of bytes.
I couldn't find a version of this that's sitting in a library we use. Very annoying.
Pandas has one, but it's private.
"""
for suffix in ("KB", "MB"):
count /= 1024
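
The new common.py helpers above are small but used throughout the rest of the diff. A quick usage sketch (assumes a local cumulus-etl checkout is importable; the file path is just an example):

```python
from cumulus_etl import common

rows = [
    {"id": "pat1", "name": "A", "extra": None},
    {"id": "pat2", "name": None, "extra": {"note": None, "kept": 1}},
]

# sparse_dict recursively drops keys whose value is None
assert common.sparse_dict(rows[1]) == {"id": "pat2", "extra": {"kept": 1}}

# write_rows_to_ndjson writes one JSON object per line; sparse=True applies
# sparse_dict to each row first. Because it opens NdjsonWriter with
# allow_empty=True, an empty rows list still produces the target file.
common.write_rows_to_ndjson("/tmp/patients.ndjson", rows, sparse=True)
```
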
2 changes: 1 addition & 1 deletion cumulus_etl/etl/cli.py
@@ -231,7 +231,7 @@ async def etl_main(args: argparse.Namespace) -> None:

async with client:
if args.input_format == "i2b2":
config_loader = loaders.I2b2Loader(root_input, args.batch_size, export_to=args.export_to)
config_loader = loaders.I2b2Loader(root_input, export_to=args.export_to)
else:
config_loader = loaders.FhirNdjsonLoader(
root_input, client=client, export_to=args.export_to, since=args.since, until=args.until
8 changes: 3 additions & 5 deletions cumulus_etl/etl/convert/cli.py
@@ -8,7 +8,6 @@
import os
import tempfile

import pandas
import rich.progress

from cumulus_etl import cli_utils, common, errors, formats, store
@@ -61,10 +60,9 @@ def convert_task_table(
progress_task = progress.add_task(table.get_name(task), total=count)

for index, ndjson_path in enumerate(ndjson_paths):
rows = common.read_ndjson(ndjson_path)
df = pandas.DataFrame(rows)
df.drop_duplicates("id", inplace=True)
formatter.write_records(df, index)
rows = list(common.read_ndjson(ndjson_path))
batch = task.make_batch_from_rows(formatter, rows, index=index)
formatter.write_records(batch)
progress.update(progress_task, advance=1)

formatter.finalize()
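
One easy-to-miss detail above: read_ndjson yields rows lazily (common.py uses it with `yield from`), so the convert path now materializes it with list() before calling make_batch_from_rows, which presumably walks the rows more than once (schema detection plus the batch itself). A standalone illustration of the pitfall, using a stand-in generator rather than the real reader:

```python
def fake_read_ndjson(path: str):
    # Stand-in generator, mirroring read_ndjson's one-shot iterator behavior
    yield {"id": "a"}
    yield {"id": "b"}

rows_iter = fake_read_ndjson("x.ndjson")
first_pass = [row["id"] for row in rows_iter]   # e.g. schema detection walks the rows...
second_pass = [row["id"] for row in rows_iter]  # ...and a second walk finds nothing
assert first_pass == ["a", "b"] and second_pass == []

rows = list(fake_read_ndjson("x.ndjson"))       # materializing first makes repeated use safe
```
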
39 changes: 38 additions & 1 deletion cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
@@ -5,8 +5,9 @@
import os

import ctakesclient
import pyarrow

from cumulus_etl import common, nlp, store
from cumulus_etl import common, formats, nlp, store
from cumulus_etl.etl import tasks
from cumulus_etl.etl.studies.covid_symptom import covid_ctakes

@@ -118,3 +119,39 @@ async def read_entries(self) -> tasks.EntryIterator:
# This way we don't need to worry about symptoms from the same note crossing batch boundaries.
# The Format class will replace all existing symptoms from this note at once (because we set group_field).
yield symptoms

@classmethod
def get_schema(cls, formatter: formats.Format, rows: list[dict]) -> pyarrow.Schema:
return pyarrow.schema(
[
pyarrow.field("id", pyarrow.string()),
pyarrow.field("docref_id", pyarrow.string()),
pyarrow.field("encounter_id", pyarrow.string()),
pyarrow.field("subject_id", pyarrow.string()),
pyarrow.field(
"match",
pyarrow.struct(
[
pyarrow.field("begin", pyarrow.int32()),
pyarrow.field("end", pyarrow.int32()),
pyarrow.field("text", pyarrow.string()),
pyarrow.field("polarity", pyarrow.int8()),
pyarrow.field("type", pyarrow.string()),
pyarrow.field(
"conceptAttributes",
pyarrow.list_(
pyarrow.struct(
[
pyarrow.field("code", pyarrow.string()),
pyarrow.field("codingScheme", pyarrow.string()),
pyarrow.field("cui", pyarrow.string()),
pyarrow.field("tui", pyarrow.string()),
]
)
),
),
]
),
),
]
)
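
Since the covid_symptom table is not a plain FHIR resource, its schema is spelled out by hand above. A trimmed-down, illustrative version (not code from this repo, example values only) showing how such a nested schema is applied when rows become an arrow table:

```python
import pyarrow

# A few fields from the schema above - enough to show the nested pieces.
schema = pyarrow.schema([
    pyarrow.field("id", pyarrow.string()),
    pyarrow.field("match", pyarrow.struct([
        pyarrow.field("begin", pyarrow.int32()),
        pyarrow.field("end", pyarrow.int32()),
        pyarrow.field("polarity", pyarrow.int8()),
        pyarrow.field("conceptAttributes", pyarrow.list_(pyarrow.struct([
            pyarrow.field("code", pyarrow.string()),
            pyarrow.field("cui", pyarrow.string()),
        ]))),
    ])),
])

rows = [{
    "id": "covid-symptom-1",
    "match": {
        "begin": 10,
        "end": 15,
        "polarity": 0,
        "conceptAttributes": [{"code": "49727002", "cui": "C0010200"}],
    },
}]

table = pyarrow.Table.from_pylist(rows, schema=schema)
print(table.schema.field("match").type)
# roughly: struct<begin: int32, end: int32, polarity: int8, conceptAttributes: list<struct<code: string, cui: string>>>
```
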
84 changes: 57 additions & 27 deletions cumulus_etl/etl/tasks/base.py
@@ -4,9 +4,9 @@
import os
from collections.abc import AsyncIterator, Iterator

import pandas
import pyarrow

from cumulus_etl import common, deid, formats, store
from cumulus_etl import common, deid, fhir, formats, store
from cumulus_etl.etl import config
from cumulus_etl.etl.tasks import batching

@@ -117,6 +117,11 @@ async def run(self) -> list[config.JobSummary]:

return self.summaries

@classmethod
def make_batch_from_rows(cls, formatter: formats.Format, rows: list[dict], index: int = 0):
schema = cls.get_schema(formatter, rows)
return formats.Batch(rows, schema=schema, index=index)

##########################################################################################
#
# Internal helpers
@@ -130,16 +135,16 @@ async def _write_tables_in_batches(self, entries: EntryIterator) -> None:
# Batches is a tuple of lists of resources - the tuple almost never matters, but it is there in case the
# task is generating multiple types of resources. Like MedicationRequest creating Medications as it goes.
# Each tuple of batches collectively adds up to roughly our target batch size.
for table_index, batch in enumerate(batches):
if not batch:
for table_index, rows in enumerate(batches):
if not rows:
continue

formatter = self._get_formatter(table_index)
batch_len = len(batch)
batch_len = len(rows)

summary = self.summaries[table_index]
summary.attempt += batch_len
if self._write_one_table_batch(formatter, batch, batch_index):
if self._write_one_table_batch(formatter, rows, batch_index):
summary.success += batch_len

self.table_batch_cleanup(table_index, batch_index)
@@ -173,41 +178,55 @@ def _get_formatter(self, table_index: int) -> formats.Format:

return self.formatters[table_index]

def _write_one_table_batch(self, formatter: formats.Format, batch: list[dict], batch_index: int) -> bool:
# Start by stuffing the batch entries into a dataframe
dataframe = pandas.DataFrame(batch)

# Drop duplicates inside the batch to guarantee to the formatter that the "id" column is unique.
# This does not fix uniqueness across batches, but formatters that care about that can control for it.
# For context:
# - We have seen duplicates inside and across files generated by Cerner bulk exports. So this is a real
# concern found in the wild, and we can't just trust input data to be "clean."
# - The deltalake backend in particular would prefer the ID to be at least unique inside a batch, so that
# it can more easily handle merge logic. Duplicate IDs across batches will be naturally overwritten as
# new batches are merged in.
# - Other backends like ndjson can currently just live with duplicates across batches, that's fine.
dataframe.drop_duplicates("id", inplace=True)
def _uniquify_rows(self, rows: list[dict]) -> list[dict]:
"""
Drop duplicates inside the batch to guarantee to the formatter that the "id" column is unique.
This does not fix uniqueness across batches, but formatters that care about that can control for it.
For context:
- We have seen duplicates inside and across files generated by Cerner bulk exports. So this is a real
concern found in the wild, and we can't just trust input data to be "clean."
- The deltalake backend in particular would prefer the ID to be at least unique inside a batch, so that
it can more easily handle merge logic. Duplicate IDs across batches will be naturally overwritten as
new batches are merged in.
- Other backends like ndjson can currently just live with duplicates across batches, that's fine.
"""
id_set = set()

def is_unique(row):
nonlocal id_set
if row["id"] in id_set:
return False
id_set.add(row["id"])
return True

return [row for row in rows if is_unique(row)]

def _write_one_table_batch(self, formatter: formats.Format, rows: list[dict], batch_index: int) -> bool:
# Checkpoint scrubber data before writing to the store, because if we get interrupted, it's safer to have an
# updated codebook with no data than data with an inaccurate codebook.
self.scrubber.save()

# Now we write that DataFrame to the target folder, in the requested format (e.g. parquet).
success = formatter.write_records(dataframe, batch_index)
rows = self._uniquify_rows(rows)
batch = self.make_batch_from_rows(formatter, rows, index=batch_index)

# Now we write that batch to the target folder, in the requested format (e.g. parquet).
success = formatter.write_records(batch)
if not success:
# We should write the "bad" dataframe to the error dir, for later review
self._write_errors(dataframe, batch_index)
# We should write the "bad" batch to the error dir, for later review
self._write_errors(batch)

return success

def _write_errors(self, df: pandas.DataFrame, index: int) -> None:
def _write_errors(self, batch: formats.Batch) -> None:
"""Takes the dataframe and writes it to the error dir, if one was provided"""
if not self.task_config.dir_errors:
return

error_root = store.Root(os.path.join(self.task_config.dir_errors, self.name), create=True)
error_path = error_root.joinpath(f"write-error.{index:03}.ndjson")
df.to_json(error_path, orient="records", lines=True, storage_options=error_root.fsspec_options())
error_path = error_root.joinpath(f"write-error.{batch.index:03}.ndjson")
common.write_rows_to_ndjson(error_path, batch.rows)

##########################################################################################
#
@@ -253,3 +272,14 @@ async def prepare_task(self) -> bool:
:returns: False if this task should be skipped and end immediately
"""
return True

@classmethod
def get_schema(cls, formatter: formats.Format, rows: list[dict]) -> pyarrow.Schema | None:
"""
Creates a properly-schema'd Table from the provided batch.
Can be overridden as needed for non-FHIR outputs.
"""
if formatter.resource_type:
return fhir.pyarrow_schema_from_resource_batch(formatter.resource_type, rows)
return None
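
For clarity, the _uniquify_rows helper above keeps the first occurrence of each id and preserves row order. A standalone restatement (illustrative, not the repo's code):

```python
def uniquify(rows: list[dict]) -> list[dict]:
    # First occurrence of each id wins; later duplicates are dropped.
    seen: set[str] = set()
    unique = []
    for row in rows:
        if row["id"] not in seen:
            seen.add(row["id"])
            unique.append(row)
    return unique

rows = [{"id": "a", "v": 1}, {"id": "b", "v": 2}, {"id": "a", "v": 3}]
assert uniquify(rows) == [{"id": "a", "v": 1}, {"id": "b", "v": 2}]
```
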
2 changes: 1 addition & 1 deletion cumulus_etl/fhir/__init__.py
@@ -1,5 +1,5 @@
"""Support for talking to FHIR servers & handling the FHIR spec"""

from .fhir_client import FhirClient, create_fhir_client_for_cli
from .fhir_schemas import create_spark_schema_for_resource
from .fhir_schemas import pyarrow_schema_from_resource_batch
from .fhir_utils import download_reference, get_docref_note, ref_resource, unref_resource