fix: make sure that no inferred schemas ever hit a delta lake
This commit fixes a couple of oversights with inferred schemas:
- The inferred type might clash with the wide schema we have per-spec,
  even though the types are really compatible (an inferred int for a
  float field does happen)
- The wrong type might get through unnoticed if it is a deep field
  that our per-spec schema didn't catch.

Here's a summary of changes to make that happen:
- Drop all use of pandas. It's too loose with the types.
  Instead, switch to pyarrow, which is used under the covers anyway.
- Add the schema earlier in the process (at Task batching time, not at
  Formatter writing time). This means that all formatters get the
  same nice schema.
mikix committed Jul 13, 2023
1 parent 75c68c9 commit 75824a9
Showing 28 changed files with 547 additions and 393 deletions.
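
A rough illustration of the type problem described in the commit message (not code from this commit, and the field names are made up): pandas infers dtypes from whatever values it happens to see, while pyarrow can be handed an explicit schema up front and will coerce compatible values, such as an int arriving for a float field.

    import pyarrow

    # Hypothetical wide schema: the spec says "value" is a float.
    schema = pyarrow.schema(
        [
            pyarrow.field("id", pyarrow.string()),
            pyarrow.field("value", pyarrow.float64()),
        ]
    )

    # Every row in this batch happens to hold a whole number, so type
    # inference alone would call the column an integer.
    rows = [{"id": "a", "value": 3}, {"id": "b", "value": 7}]

    # With an explicit schema, the ints are coerced to float64 instead of
    # an integer column reaching the delta lake.
    table = pyarrow.Table.from_pylist(rows, schema=schema)
    assert table.schema.field("value").type == pyarrow.float64()
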
67 changes: 55 additions & 12 deletions cumulus_etl/common.py
@@ -9,6 +9,8 @@
from collections.abc import Iterator
from typing import Any, TextIO

import pyarrow

from cumulus_etl import store


@@ -37,10 +39,14 @@ def _atomic_open(path: str, mode: str) -> TextIO:
"""A version of open() that handles atomic file access across many filesystems (like S3)"""
root = store.Root(path)

# fsspec is atomic per-transaction -- if an error occurs inside the transaction, partial writes will be discarded
with root.fs.transaction:
with root.fs.open(path, mode=mode, encoding="utf8") as file:
yield file
with contextlib.ExitStack() as stack:
if "w" in mode:
# fsspec is atomic per-transaction.
# If an error occurs inside the transaction, partial writes will be discarded.
# But we only want a transaction if we're writing - read transactions may error out
stack.enter_context(root.fs.transaction)

yield stack.enter_context(root.fs.open(path, mode=mode, encoding="utf8"))


def read_text(path: str) -> str:
@@ -94,7 +100,7 @@ def write_json(path: str, data: Any, indent: int = None) -> None:

@contextlib.contextmanager
def read_csv(path: str) -> csv.DictReader:
with open(path, newline="", encoding="utf8") as csvfile:
with _atomic_open(path, "r") as csvfile:
yield csv.DictReader(csvfile)


@@ -115,13 +121,35 @@ def read_resource_ndjson(root, resource: str) -> Iterator[dict]:
yield from read_ndjson(filename)


def write_table_to_ndjson(path: str, table: pyarrow.Table, sparse: bool = False) -> None:
"""
Writes the pyarrow Table out, row by row, to an .ndjson file (non-atomically).
:param path: where to write the file
:param table: data to write
:param sparse: if True, None entries are skipped
"""
with NdjsonWriter(path, allow_empty=True) as f:
for batch in table.to_batches():
for row in batch.to_pylist():
if sparse:
row = sparse_dict(row)
f.write(row)


class NdjsonWriter:
"""Convenience context manager to write multiple objects to a local ndjson file."""
"""
Convenience context manager to write multiple objects to a ndjson file.
def __init__(self, path: str, mode: str = "w"):
self._path = path
Note that this is not atomic - partial writes will make it to the target file.
"""

def __init__(self, path: str, mode: str = "w", allow_empty: bool = False):
self._root = store.Root(path)
self._mode = mode
self._file = None
if allow_empty:
self._ensure_file()

def __enter__(self):
return self
@@ -131,15 +159,31 @@ def __exit__(self, exc_type, exc_value, traceback):
self._file.close()
self._file = None

    def write(self, obj: dict) -> None:
        # lazily create the file, to avoid 0-line ndjson files
        if not self._file:
            self._file = open(self._path, self._mode, encoding="utf8")  # pylint: disable=consider-using-with
    def _ensure_file(self):
        if not self._file:
            self._file = self._root.fs.open(self._root.path, self._mode, encoding="utf8")

    def write(self, obj: dict) -> None:
        # lazily create the file, to avoid 0-line ndjson files (unless created in __init__)
        self._ensure_file()

        json.dump(obj, self._file)
        self._file.write("\n")


def sparse_dict(dictionary: dict) -> dict:
"""Returns a value of the input dictionary without any keys with None values."""

def iteration(item: Any) -> Any:
if isinstance(item, dict):
return {key: iteration(val) for key, val in item.items() if val is not None}
elif isinstance(item, list):
return [iteration(x) for x in item]
return item

return iteration(dictionary)


###############################################################################
#
# Helper Functions: Logging
@@ -162,7 +206,6 @@ def human_file_size(count: int) -> str:
Returns a human-readable version of a count of bytes.
I couldn't find a version of this that's sitting in a library we use. Very annoying.
Pandas has one, but it's private.
"""
for suffix in ("KB", "MB"):
count /= 1024
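
A small usage sketch for the two helpers added to common.py above, assuming only the signatures shown in this diff; the rows and the output path are made up.

    import pyarrow

    from cumulus_etl import common

    rows = [{"id": "pat1", "gender": "female", "deceasedBoolean": None}]
    table = pyarrow.Table.from_pylist(rows)

    # sparse_dict() strips keys whose values are None, recursively.
    assert common.sparse_dict(rows[0]) == {"id": "pat1", "gender": "female"}

    # sparse=True routes every row through sparse_dict() on the way out, so the
    # None-valued "deceasedBoolean" never appears in the written file.
    common.write_table_to_ndjson("/tmp/patients.ndjson", table, sparse=True)
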
2 changes: 1 addition & 1 deletion cumulus_etl/etl/cli.py
@@ -231,7 +231,7 @@ async def etl_main(args: argparse.Namespace) -> None:

async with client:
if args.input_format == "i2b2":
config_loader = loaders.I2b2Loader(root_input, args.batch_size, export_to=args.export_to)
config_loader = loaders.I2b2Loader(root_input, export_to=args.export_to)
else:
config_loader = loaders.FhirNdjsonLoader(
root_input, client=client, export_to=args.export_to, since=args.since, until=args.until
8 changes: 3 additions & 5 deletions cumulus_etl/etl/convert/cli.py
@@ -8,7 +8,6 @@
import os
import tempfile

import pandas
import rich.progress

from cumulus_etl import cli_utils, common, errors, formats, store
@@ -61,10 +60,9 @@ def convert_task_table(
progress_task = progress.add_task(table.get_name(task), total=count)

for index, ndjson_path in enumerate(ndjson_paths):
rows = common.read_ndjson(ndjson_path)
df = pandas.DataFrame(rows)
df.drop_duplicates("id", inplace=True)
formatter.write_records(df, index)
rows = list(common.read_ndjson(ndjson_path))
table = task.make_table_from_batch(formatter, rows)
formatter.write_records(table, index)
progress.update(progress_task, advance=1)

formatter.finalize()
39 changes: 38 additions & 1 deletion cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
@@ -5,8 +5,9 @@
import os

import ctakesclient
import pyarrow

from cumulus_etl import common, nlp, store
from cumulus_etl import common, formats, nlp, store
from cumulus_etl.etl import tasks
from cumulus_etl.etl.studies.covid_symptom import covid_ctakes

@@ -118,3 +119,39 @@ async def read_entries(self) -> tasks.EntryIterator:
# This way we don't need to worry about symptoms from the same note crossing batch boundaries.
# The Format class will replace all existing symptoms from this note at once (because we set group_field).
yield symptoms

@classmethod
def get_schema(cls, formatter: formats.Format, batch: list[dict]) -> pyarrow.Schema:
return pyarrow.schema(
[
pyarrow.field("id", pyarrow.string()),
pyarrow.field("docref_id", pyarrow.string()),
pyarrow.field("encounter_id", pyarrow.string()),
pyarrow.field("subject_id", pyarrow.string()),
pyarrow.field(
"match",
pyarrow.struct(
[
pyarrow.field("begin", pyarrow.int32()),
pyarrow.field("end", pyarrow.int32()),
pyarrow.field("text", pyarrow.string()),
pyarrow.field("polarity", pyarrow.int8()),
pyarrow.field("type", pyarrow.string()),
pyarrow.field(
"conceptAttributes",
pyarrow.list_(
pyarrow.struct(
[
pyarrow.field("code", pyarrow.string()),
pyarrow.field("codingScheme", pyarrow.string()),
pyarrow.field("cui", pyarrow.string()),
pyarrow.field("tui", pyarrow.string()),
]
)
),
),
]
),
),
]
)
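
An aside on why spelling out deep fields like match.begin matters (hypothetical values, not part of this commit): with a nested pyarrow schema, a wrong type buried inside a struct fails loudly at conversion time rather than slipping through unnoticed.

    import pyarrow

    match_type = pyarrow.struct(
        [
            pyarrow.field("begin", pyarrow.int32()),
            pyarrow.field("end", pyarrow.int32()),
            pyarrow.field("text", pyarrow.string()),
        ]
    )
    schema = pyarrow.schema(
        [pyarrow.field("id", pyarrow.string()), pyarrow.field("match", match_type)]
    )

    good = [{"id": "a", "match": {"begin": 0, "end": 4, "text": "achy"}}]
    bad = [{"id": "b", "match": {"begin": "zero", "end": 4, "text": "achy"}}]

    pyarrow.Table.from_pylist(good, schema=schema)  # converts cleanly

    try:
        pyarrow.Table.from_pylist(bad, schema=schema)
    except (pyarrow.ArrowInvalid, pyarrow.ArrowTypeError):
        print("deep type mismatch caught before it ever reaches the lake")
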
75 changes: 55 additions & 20 deletions cumulus_etl/etl/tasks/base.py
@@ -4,9 +4,9 @@
import os
from collections.abc import AsyncIterator, Iterator

import pandas
import pyarrow

from cumulus_etl import common, deid, formats, store
from cumulus_etl import common, deid, fhir, formats, store
from cumulus_etl.etl import config
from cumulus_etl.etl.tasks import batching

@@ -117,6 +117,11 @@ async def run(self) -> list[config.JobSummary]:

return self.summaries

@classmethod
def make_table_from_batch(cls, formatter: formats.Format, batch: list[dict]):
schema = cls.get_schema(formatter, batch)
return pyarrow.Table.from_pylist(batch, schema=schema)

##########################################################################################
#
# Internal helpers
@@ -173,41 +178,60 @@ def _get_formatter(self, table_index: int) -> formats.Format:

return self.formatters[table_index]

def _uniquify_batch(self, batch: list[dict]) -> list[dict]:
"""
Drop duplicates inside the batch to guarantee to the formatter that the "id" column is unique.
This does not fix uniqueness across batches, but formatters that care about that can control for it.
For context:
- We have seen duplicates inside and across files generated by Cerner bulk exports. So this is a real
concern found in the wild, and we can't just trust input data to be "clean."
- The deltalake backend in particular would prefer the ID to be at least unique inside a batch, so that
it can more easily handle merge logic. Duplicate IDs across batches will be naturally overwritten as
new batches are merged in.
- Other backends like ndjson can currently just live with duplicates across batches, that's fine.
"""
id_set = set()

def is_unique(row):
nonlocal id_set
if row["id"] in id_set:
return False
id_set.add(row["id"])
return True

return [row for row in batch if is_unique(row)]

def _write_one_table_batch(self, formatter: formats.Format, batch: list[dict], batch_index: int) -> bool:
# Start by stuffing the batch entries into a dataframe
dataframe = pandas.DataFrame(batch)

# Drop duplicates inside the batch to guarantee to the formatter that the "id" column is unique.
# This does not fix uniqueness across batches, but formatters that care about that can control for it.
# For context:
# - We have seen duplicates inside and across files generated by Cerner bulk exports. So this is a real
# concern found in the wild, and we can't just trust input data to be "clean."
# - The deltalake backend in particular would prefer the ID to be at least unique inside a batch, so that
# it can more easily handle merge logic. Duplicate IDs across batches will be naturally overwritten as
# new batches are merged in.
# - Other backends like ndjson can currently just live with duplicates across batches, that's fine.
dataframe.drop_duplicates("id", inplace=True)
batch = self._uniquify_batch(batch)

# Start by stuffing the batch entries into a pyarrow Table
schema = self.get_schema(formatter, batch)
table = pyarrow.Table.from_pylist(batch, schema=schema)

# Checkpoint scrubber data before writing to the store, because if we get interrupted, it's safer to have an
# updated codebook with no data than data with an inaccurate codebook.
self.scrubber.save()

# Now we write that Table to the target folder, in the requested format (e.g. parquet).
success = formatter.write_records(dataframe, batch_index)
success = formatter.write_records(table, batch_index)
if not success:
# We should write the "bad" dataframe to the error dir, for later review
self._write_errors(dataframe, batch_index)
# We should write the "bad" batch to the error dir, for later review
self._write_errors(batch, batch_index)

return success

def _write_errors(self, df: pandas.DataFrame, index: int) -> None:
def _write_errors(self, batch: list[dict], index: int) -> None:
"""Takes the dataframe and writes it to the error dir, if one was provided"""
if not self.task_config.dir_errors:
return

error_root = store.Root(os.path.join(self.task_config.dir_errors, self.name), create=True)
error_path = error_root.joinpath(f"write-error.{index:03}.ndjson")
df.to_json(error_path, orient="records", lines=True, storage_options=error_root.fsspec_options())
with common.NdjsonWriter(error_path) as f:
for row in batch:
f.write(row)

##########################################################################################
#
Expand Down Expand Up @@ -253,3 +277,14 @@ async def prepare_task(self) -> bool:
:returns: False if this task should be skipped and end immediately
"""
return True

@classmethod
def get_schema(cls, formatter: formats.Format, batch: list[dict]) -> pyarrow.Schema | None:
"""
Creates a properly-schema'd Table from the provided batch.
Can be overridden as needed for non-FHIR outputs.
"""
if formatter.resource_type:
return fhir.pyarrow_schema_from_resource_batch(formatter.resource_type, batch)
return None
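
A compressed sketch of what the new write path in _write_one_table_batch() amounts to (dedup first, then build a schema'd Table); the standalone uniquify() helper and the valueQuantity field are illustrative stand-ins, not names from this codebase.

    import pyarrow

    def uniquify(batch: list[dict]) -> list[dict]:
        # Same first-row-wins semantics as _uniquify_batch() above.
        seen = set()
        unique = []
        for row in batch:
            if row["id"] not in seen:
                seen.add(row["id"])
                unique.append(row)
        return unique

    batch = [
        {"id": "obs1", "valueQuantity": {"value": 3}},  # int where the spec wants a float
        {"id": "obs1", "valueQuantity": {"value": 3}},  # duplicate id, as seen in real bulk exports
    ]
    schema = pyarrow.schema(
        [
            pyarrow.field("id", pyarrow.string()),
            pyarrow.field(
                "valueQuantity", pyarrow.struct([pyarrow.field("value", pyarrow.float64())])
            ),
        ]
    )

    table = pyarrow.Table.from_pylist(uniquify(batch), schema=schema)
    assert table.num_rows == 1  # the duplicate row was dropped before it hit the formatter
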
2 changes: 1 addition & 1 deletion cumulus_etl/fhir/__init__.py
@@ -1,5 +1,5 @@
"""Support for talking to FHIR servers & handling the FHIR spec"""

from .fhir_client import FhirClient, create_fhir_client_for_cli
from .fhir_schemas import create_spark_schema_for_resource
from .fhir_schemas import pyarrow_schema_from_resource_batch
from .fhir_utils import download_reference, get_docref_note, ref_resource, unref_resource
