Skip to content

Commit

Permalink
Merge pull request #255 from smart-on-fhir/mikix/ordered-tasks
Browse files Browse the repository at this point in the history
feat: order tasks to minimize codebook writes
  • Loading branch information
mikix authored Jul 31, 2023
2 parents cf310fc + 7008466 commit a92ea1e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 27 deletions.
56 changes: 29 additions & 27 deletions cumulus_etl/etl/tasks/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def get_all_tasks() -> list[type[AnyTask]]:
:returns: a list of all EtlTask subclasses, to instantiate and run
"""
# Right now, just hard-code these. One day we might allow plugins or something similarly dynamic.
# Note: tasks will be run in the order listed here.
return [
# Run encounter & patient first, to reduce churn on the codebook (the cached mappings would mostly be written
# out during the encounter task and wouldn't need to be re-written later, one would hope)
EncounterTask,
PatientTask,
# The rest of the tasks in alphabetical order, why not:
ConditionTask,
DocumentReferenceTask,
EncounterTask,
MedicationRequestTask,
ObservationTask,
PatientTask,
ProcedureTask,
ServiceRequestTask,
covid_symptom.CovidSymptomNlpResultsTask,
Expand All @@ -49,39 +53,37 @@ def get_selected_tasks(names: Iterable[str] = None, filter_tags: Iterable[str] =
:returns: a list of selected EtlTask subclasses, to instantiate and run
"""
all_tasks = get_all_tasks()
names = names and set(names)

# Filter out any tasks that don't have every required tag
filter_tag_set = frozenset(filter_tags or [])
filtered_tasks = filter(lambda x: filter_tag_set.issubset(x.tags), all_tasks)
filtered_tasks = list(filter(lambda x: filter_tag_set.issubset(x.tags), all_tasks))

# If the user didn't list any names, great! We're done.
if names is None:
selected_tasks = list(filtered_tasks)
if not selected_tasks:
if not names:
if not filtered_tasks:
print_filter_tags = ", ".join(sorted(filter_tag_set))
print(f"No tasks left after filtering for '{print_filter_tags}'.", file=sys.stderr)
raise SystemExit(errors.TASK_SET_EMPTY)
return selected_tasks
return filtered_tasks

# They did list names, so now we validate those names and select those tasks.
all_task_names = {t.name for t in all_tasks}
filtered_task_mapping = {t.name: t for t in filtered_tasks}
selected_tasks = []

for name in names:
if name not in all_task_names:
print_names = "\n".join(sorted(f" {key}" for key in all_task_names))
print(f"Unknown task '{name}' requested. Valid task names:\n{print_names}", file=sys.stderr)
raise SystemExit(errors.TASK_UNKNOWN)

if name not in filtered_task_mapping:
print_filter_tags = ", ".join(sorted(filter_tag_set))
print(
f"Task '{name}' requested but it does not match the task filter '{print_filter_tags}'.",
file=sys.stderr,
)
raise SystemExit(errors.TASK_FILTERED_OUT)

selected_tasks.append(filtered_task_mapping[name])

return selected_tasks
# Check for unknown names the user gave us
all_task_names = {t.name for t in all_tasks}
if unknown_names := names - all_task_names:
print_names = "\n".join(sorted(f" {key}" for key in all_task_names))
print(f"Unknown task '{unknown_names.pop()}' requested. Valid task names:\n{print_names}", file=sys.stderr)
raise SystemExit(errors.TASK_UNKNOWN)

# Check for names that conflict with the chosen filters
filtered_task_names = {t.name for t in filtered_tasks}
if unfiltered_names := names - filtered_task_names:
print_filter_tags = ", ".join(sorted(filter_tag_set))
print(
f"Task '{unfiltered_names.pop()}' requested but it does not match the task filter '{print_filter_tags}'.",
file=sys.stderr,
)
raise SystemExit(errors.TASK_FILTERED_OUT)

return [task for task in filtered_tasks if task.name in names]
15 changes: 15 additions & 0 deletions tests/etl/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ def test_filtered_but_named_task(self):
tasks.get_selected_tasks(names=["condition"], filter_tags=["gpu"])
self.assertEqual(errors.TASK_FILTERED_OUT, cm.exception.code)

@ddt.data(
(None, "all"),
([], "all"),
(filter(None, []), "all"), # iterable, not list
(["observation", "condition", "procedure"], ["condition", "observation", "procedure"]), # re-ordered
(["condition", "patient", "encounter"], ["encounter", "patient", "condition"]), # encounter and patient first
)
@ddt.unpack
def test_task_selection_ordering(self, user_tasks, expected_tasks):
"""Verify we define the order, not the user, and that encounter & patient are early"""
names = [t.name for t in tasks.get_selected_tasks(names=user_tasks)]
if expected_tasks == "all":
expected_tasks = [t.name for t in tasks.get_all_tasks()]
self.assertEqual(expected_tasks, names)

async def test_drop_duplicates(self):
"""Verify that we run() will drop duplicate rows inside an input batch."""
# Two "A" ids and one "B" id
Expand Down

0 comments on commit a92ea1e

Please sign in to comment.