feat: use progress bars in a lot more places #256

Merged
merged 1 commit into from
Aug 1, 2023
14 changes: 12 additions & 2 deletions cumulus_etl/common.py
@@ -3,6 +3,7 @@
import contextlib
import csv
import datetime
import itertools
import json
import logging
import re
@@ -39,7 +40,7 @@ def __init__(self, path: str):
###############################################################################


def ls_resources(root, resource: str) -> list[str]:
def ls_resources(root: store.Root, resource: str) -> list[str]:
pattern = re.compile(rf".*/([0-9]+.)?{resource}(.[0-9]+)?.ndjson")
all_files = root.ls()
return sorted(filter(pattern.match, all_files))
@@ -130,7 +131,7 @@ def read_ndjson(path: str) -> Iterator[dict]:
yield json.loads(line)


def read_resource_ndjson(root, resource: str) -> Iterator[dict]:
def read_resource_ndjson(root: store.Root, resource: str) -> Iterator[dict]:
"""
Grabs all ndjson files from a folder, of a particular resource type.

@@ -189,6 +190,15 @@ def write(self, obj: dict) -> None:
self._file.write("\n")


def read_local_line_count(path) -> int:
"""Reads a local file and provides the count of new line characters."""
# From https://stackoverflow.com/a/27517681/239668
# Copyright Michael Bacon, licensed CC-BY-SA 3.0
with open(path, "rb") as f:
bufgen = itertools.takewhile(lambda x: x, (f.raw.read(1024 * 1024) for _ in itertools.repeat(None)))
return sum(buf.count(b"\n") for buf in bufgen if buf)


def sparse_dict(dictionary: dict) -> dict:
"""Returns a value of the input dictionary without any keys with None values."""

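For context, the new `read_local_line_count` helper counts newlines by scanning the file in raw 1 MiB chunks, which stays fast and memory-light even for large ndjson exports. Below is a minimal standalone sketch of the same technique; the sample path and the `count_lines` name are hypothetical, not part of the PR:

```python
import itertools


def count_lines(path: str) -> int:
    """Count newline characters by scanning raw 1 MiB buffers."""
    with open(path, "rb") as f:
        # f.raw bypasses Python-level buffering; takewhile stops at EOF (empty bytes).
        buffers = itertools.takewhile(
            lambda chunk: chunk,
            (f.raw.read(1024 * 1024) for _ in itertools.repeat(None)),
        )
        return sum(buf.count(b"\n") for buf in buffers)


# Hypothetical usage: ndjson stores one resource per line, so the line count
# approximates how many resources a task will process.
print(count_lines("/tmp/example.ndjson"))
```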
2 changes: 1 addition & 1 deletion cumulus_etl/deid/mstool.py
@@ -48,7 +48,7 @@ async def _wait_for_completion(process: asyncio.subprocess.Process, input_dir: s
stdout, stderr = None, None

with cli_utils.make_progress_bar() as progress:
task = progress.add_task("De-identifying data", total=1)
task = progress.add_task("De-identifying data", total=1)
target = _count_file_sizes(f"{input_dir}/*.ndjson")

while process.returncode is None:
8 changes: 6 additions & 2 deletions cumulus_etl/etl/cli.py
@@ -212,6 +212,7 @@ async def etl_main(args: argparse.Namespace) -> None:

# Print configuration
print_config(args, job_datetime, selected_tasks)
common.print_header() # all "prep" comes in this next section, like connecting to server, bulk export, and de-id

if args.errors_to:
cli_utils.confirm_dir_is_empty(args.errors_to)
@@ -267,11 +268,14 @@ async def etl_main(args: argparse.Namespace) -> None:
job_context.last_successful_output_dir = args.dir_output
job_context.save()

# If any task had a failure, flag that for the user
# Flag final status to user
common.print_header()
failed = any(s.success < s.attempt for s in summaries)
if failed:
print("** One or more tasks above did not 100% complete! **", file=sys.stderr)
print("🚨 One or more tasks above did not 100% complete! 🚨", file=sys.stderr)
Contributor commented:
that's the stuff

Contributor (author) commented:
I literally can't believe I didn't start with those

raise SystemExit(errors.TASK_FAILED)
else:
print("⭐ All tasks completed successfully! ⭐", file=sys.stderr)


async def run_etl(parser: argparse.ArgumentParser, argv: list[str]) -> None:
77 changes: 58 additions & 19 deletions cumulus_etl/etl/tasks/base.py
@@ -1,12 +1,17 @@
"""ETL tasks"""

import contextlib
import dataclasses
import os
from collections.abc import AsyncIterator, Iterator

import pyarrow
import rich.live
import rich.progress
import rich.table
import rich.text

from cumulus_etl import common, deid, fhir, formats, store
from cumulus_etl import cli_utils, common, deid, fhir, formats, store
from cumulus_etl.etl import config
from cumulus_etl.etl.tasks import batching

@@ -95,25 +100,35 @@ async def run(self) -> list[config.JobSummary]:
"""
common.print_header(f"{self.name}:")

if not await self.prepare_task():
return self.summaries
# Set up progress table with a slight left indent
grid = rich.table.Table.grid(padding=(0, 0, 0, 1), pad_edge=True)
progress = cli_utils.make_progress_bar()
text_box = rich.text.Text()
grid.add_row(progress)
grid.add_row(text_box)

entries = self.read_entries()
with rich.live.Live(grid):
with self._indeterminate_progress(progress, "Preparing"):
if not await self.prepare_task():
return self.summaries

# At this point we have a giant iterable of de-identified FHIR objects, ready to be written out.
# We want to batch them up, to allow resuming from interruptions more easily.
await self._write_tables_in_batches(entries)
entries = self.read_entries()
total_batches = self._count_total_batches()

# Ensure that we touch every output table (to create them and/or to confirm schema).
# Consider case of Medication for an EHR that only has inline Medications inside MedicationRequest.
# The Medication table wouldn't get created otherwise. Plus this is a good place to push any schema changes.
# (The reason it's nice if the table & schema exist is so that downstream SQL can be dumber.)
self._touch_remaining_tables()
# At this point we have a giant iterable of de-identified FHIR objects, ready to be written out.
# We want to batch them up, to allow resuming from interruptions more easily.
await self._write_tables_in_batches(entries, total=total_batches, progress=progress, status=text_box)

# All data is written, now do any final cleanup the formatters want
for table_index, formatter in enumerate(self.formatters):
formatter.finalize()
print(f" ⭐ done with {formatter.dbname} ({self.summaries[table_index].success:,} processed) ⭐")
with self._indeterminate_progress(progress, "Finalizing"):
# Ensure that we touch every output table (to create them and/or to confirm schema).
# Consider case of Medication for an EHR that only has inline Medications inside MedicationRequest.
# The Medication table wouldn't get created otherwise. Plus this is a good place to push any schema
# changes. (The reason it's nice if the table & schema exist is so that downstream SQL can be dumber.)
self._touch_remaining_tables()

# All data is written, now do any final cleanup the formatters want
for formatter in self.formatters:
formatter.finalize()

return self.summaries

@@ -128,9 +143,33 @@ def make_batch_from_rows(cls, formatter: formats.Format, rows: list[dict], index
#
##########################################################################################

async def _write_tables_in_batches(self, entries: EntryIterator) -> None:
@contextlib.contextmanager
def _indeterminate_progress(self, progress: rich.progress.Progress, description: str):
task = progress.add_task(description=description, total=None)
yield
progress.update(task, completed=1, total=1)

def _count_total_batches(self):
input_root = store.Root(self.task_config.dir_input)
filenames = common.ls_resources(input_root, self.resource)
line_count = sum(common.read_local_line_count(filename) for filename in filenames)
num_batches, remainder = divmod(line_count, self.task_config.batch_size)
if remainder:
num_batches += 1
return num_batches

async def _write_tables_in_batches(
self, entries: EntryIterator, *, total: int, progress: rich.progress.Progress, status: rich.text.Text
) -> None:
"""Writes all entries to each output tables in batches"""

def update_status():
status.plain = "\n".join(f"{x.success:,} processed for {x.label}" for x in self.summaries)

batch_index = 0
batch_task = progress.add_task(f"0/{total} batches", total=total)
update_status()

async for batches in batching.batch_iterate(entries, self.task_config.batch_size):
# Batches is a tuple of lists of resources - the tuple almost never matters, but it is there in case the
# task is generating multiple types of resources. Like MedicationRequest creating Medications as it goes.
Expand All @@ -146,12 +185,12 @@ async def _write_tables_in_batches(self, entries: EntryIterator) -> None:
summary.attempt += batch_len
if self._write_one_table_batch(formatter, rows, batch_index):
summary.success += batch_len
update_status()

self.table_batch_cleanup(table_index, batch_index)

print(f" {summary.success:,} processed for {formatter.dbname}")

batch_index += 1
progress.update(batch_task, description=f"{batch_index}/{total} batches", advance=1)

def _touch_remaining_tables(self):
"""Writes empty dataframe to any table we haven't written to yet"""
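To illustrate the display pattern the task runner now uses: a rich `Table.grid` stacks a progress bar over a `Text` status box, and both refresh together inside one `rich.live.Live` region, while batch totals come from ceiling division of the line count. The sketch below is a simplified demo under invented numbers and a fake work loop, not the project's actual task class:

```python
import time

import rich.live
import rich.progress
import rich.table
import rich.text


def demo_batches(line_count: int = 2_500, batch_size: int = 1_000) -> None:
    # Ceiling division, as in _count_total_batches: a partial batch still counts.
    total, remainder = divmod(line_count, batch_size)
    if remainder:
        total += 1

    progress = rich.progress.Progress()
    status = rich.text.Text()

    # A grid with a small left pad keeps the bar and the status text visually grouped.
    grid = rich.table.Table.grid(padding=(0, 0, 0, 1), pad_edge=True)
    grid.add_row(progress)
    grid.add_row(status)

    with rich.live.Live(grid):
        task = progress.add_task(f"0/{total} batches", total=total)
        for index in range(total):
            time.sleep(0.5)  # stand-in for writing one batch
            done = min((index + 1) * batch_size, line_count)
            status.plain = f"{done:,} rows processed (demo numbers)"
            progress.update(task, description=f"{index + 1}/{total} batches", advance=1)


if __name__ == "__main__":
    demo_batches()
```

Mutating the shared `Text` object in place is what lets the status line update without printing a new line per batch, which is the same trick the PR applies in `_write_tables_in_batches`.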
2 changes: 2 additions & 0 deletions cumulus_etl/fhir/fhir_client.py
@@ -154,6 +154,8 @@ async def _read_capabilities(self) -> None:
if not self._server_root:
return

print("Connecting to server…")

try:
response = await self._session.get(
fhir_auth.urljoin(self._server_root, "metadata"),
73 changes: 40 additions & 33 deletions cumulus_etl/loaders/fhir/bulk_export.py
@@ -6,6 +6,8 @@
import urllib.parse

import httpx
import rich.live
import rich.text

from cumulus_etl import common, errors, fhir

@@ -69,7 +71,7 @@ async def export(self) -> None:
See http://hl7.org/fhir/uv/bulkdata/export/index.html for details.
"""
# Initiate bulk export
common.print_header("Starting bulk FHIR export...")
print("Starting bulk FHIR export")

params = {"_type": ",".join(self._resources)}
if self._since:
@@ -104,7 +106,7 @@ async def export(self) -> None:
print("\n - ".join(["Messages from server:"] + warning_texts))

# Download all the files
print("Bulk FHIR export finished, now downloading resources...")
print("Bulk FHIR export finished, now downloading resources")
files = response_json.get("output", [])
await self._download_all_ndjson_files(files)
finally:
@@ -136,37 +138,42 @@ async def _request_with_delay(
:param method: HTTP method to request
:returns: the HTTP response
"""
while self._total_wait_time < self._TIMEOUT_THRESHOLD:
Contributor (author) commented:
This whole change is basically just an indent to support a new Status region with one Text field that we write the ongoing updates to, instead of just printing them to the console.

response = await self._client.request(method, path, headers=headers)

if response.status_code == target_status_code:
return response

# 202 == server is still working on it, 429 == server is busy -- in both cases, we wait
if response.status_code in [202, 429]:
# Print a message to the user, so they don't see us do nothing for a while
delay = int(response.headers.get("Retry-After", 60))
if response.status_code == 202:
# Some servers can request unreasonably long delays (e.g. I've seen Cerner ask for five hours),
# which is... not helpful for our UX and often way too long for small exports.
# So as long as the server isn't telling us it's overloaded, limit the delay time to five minutes.
delay = min(delay, 300)
progress_msg = response.headers.get("X-Progress", "waiting...")
formatted_total = common.human_time_offset(self._total_wait_time)
formatted_delay = common.human_time_offset(delay)
print(f" {progress_msg} ({formatted_total} so far, waiting for {formatted_delay} more)")

# And wait as long as the server requests
await asyncio.sleep(delay)
self._total_wait_time += delay

else:
# It feels silly to abort on an unknown *success* code, but the spec has such clear guidance on
# what the expected response codes are, that it's not clear if a code outside those parameters means
# we should keep waiting or stop waiting. So let's be strict here for now.
raise errors.FatalError(
f"Unexpected status code {response.status_code} from the bulk FHIR export server."
)
status_box = rich.text.Text()
with rich.get_console().status(status_box) as status:
while self._total_wait_time < self._TIMEOUT_THRESHOLD:
response = await self._client.request(method, path, headers=headers)

if response.status_code == target_status_code:
if status_box.plain:
status.stop()
print(f" Waited for a total of {common.human_time_offset(self._total_wait_time)}")
return response

# 202 == server is still working on it, 429 == server is busy -- in both cases, we wait
if response.status_code in [202, 429]:
# Print a message to the user, so they don't see us do nothing for a while
delay = int(response.headers.get("Retry-After", 60))
if response.status_code == 202:
# Some servers can request unreasonably long delays (e.g. I've seen Cerner ask for five hours),
# which is... not helpful for our UX and often way too long for small exports.
# So as long as the server isn't telling us it's overloaded, limit the delay time to 5 minutes.
delay = min(delay, 300)
progress_msg = response.headers.get("X-Progress", "waiting…")
formatted_total = common.human_time_offset(self._total_wait_time)
formatted_delay = common.human_time_offset(delay)
status_box.plain = f"{progress_msg} ({formatted_total} so far, waiting for {formatted_delay} more)"

# And wait as long as the server requests
await asyncio.sleep(delay)
self._total_wait_time += delay

else:
# It feels silly to abort on an unknown *success* code, but the spec has such clear guidance on
# what the expected response codes are, that it's not clear if a code outside those parameters means
# we should keep waiting or stop waiting. So let's be strict here for now.
raise errors.FatalError(
f"Unexpected status code {response.status_code} from the bulk FHIR export server."
)

raise errors.FatalError("Timed out waiting for the bulk FHIR export to finish.")

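The author's comment above sums up the bulk-export change: the polling loop now writes its "still waiting" updates into a single rich status line instead of printing a fresh line each poll, and it still honors (but caps) the server's `Retry-After` hint. Here is a rough sketch of that pattern; `poll_once` is an invented stand-in for the real HTTP status request, and the numbers are demo values:

```python
import random
import time

import rich
import rich.text


def poll_once() -> tuple[int, dict]:
    """Hypothetical stand-in for the bulk-export status request."""
    if random.random() < 0.3:
        return 200, {}
    return 202, {"Retry-After": "5", "X-Progress": "still exporting"}


def wait_for_export(timeout: int = 60) -> None:
    total_wait = 0
    status_box = rich.text.Text()
    with rich.get_console().status(status_box) as status:
        while total_wait < timeout:
            code, headers = poll_once()
            if code == 200:
                status.stop()  # stop the spinner before printing the final message
                print(f"Done after {total_wait}s of waiting")
                return
            # 202 means "keep waiting"; honor Retry-After but cap it at five minutes,
            # since some servers suggest unhelpfully long delays.
            delay = min(int(headers.get("Retry-After", 60)), 300)
            progress_msg = headers.get("X-Progress", "waiting…")
            status_box.plain = f"{progress_msg} ({total_wait}s so far, waiting {delay}s more)"
            time.sleep(delay)
            total_wait += delay
    raise RuntimeError("Timed out waiting for the export to finish")


if __name__ == "__main__":
    wait_for_export()
```

Updating `status_box.plain` in place keeps the console on one spinner line, which matches how the diff reports export progress between polls.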
2 changes: 1 addition & 1 deletion cumulus_etl/loaders/fhir/ndjson_loader.py
@@ -56,7 +56,7 @@ async def load_all(self, resources: list[str]) -> common.Directory:
#
# This uses more disk space temporarily (copied files will get deleted once the MS tool is done and this
# TemporaryDirectory gets discarded), but that seems reasonable.
common.print_header("Copying ndjson input files…")
print("Copying ndjson input files…")
tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
for resource in resources:
filenames = common.ls_resources(self.root, resource)