fix: guard against accidental deletes with ndjson output format #281

Merged 1 commit on Oct 10, 2023
23 changes: 23 additions & 0 deletions cumulus_etl/formats/batched_files.py
@@ -1,7 +1,9 @@
"""An implementation of Format designed to write in batches of files"""

import abc
import re

from cumulus_etl import errors, store
from cumulus_etl.formats.base import Format
from cumulus_etl.formats.batch import Batch

@@ -42,11 +44,32 @@ def __init__(self, *args, **kwargs) -> None:
# Note: There is a real issue here where Athena will see invalid results until we've written all
# our files out. Use the deltalake format to get atomic updates.
parent_dir = self.root.joinpath(self.dbname)
self._confirm_no_unknown_files_exist(parent_dir)
try:
self.root.rm(parent_dir, recursive=True)
except FileNotFoundError:
pass

def _confirm_no_unknown_files_exist(self, folder: str) -> None:
"""
Errors out if any unknown files exist in the target dir already.

This is designed to prevent accidents.
"""
try:
filenames = [path.split("/")[-1] for path in store.Root(folder).ls()]
except FileNotFoundError:
return # folder doesn't exist, we're good!

allowed_pattern = re.compile(rf"{self.dbname}\.[0-9]+\.{self.suffix}")
if not all(map(allowed_pattern.fullmatch, filenames)):
errors.fatal(
f"There are unexpected files in the output folder '{folder}'.\n"
f"Please confirm you are using the right output format.\n"
f"If so, delete the output folder and try again.",
errors.FOLDER_NOT_EMPTY,
)

def _write_one_batch(self, batch: Batch) -> None:
"""Writes the whole dataframe to a single file"""
self.root.makedirs(self.root.joinpath(self.dbname))
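For reference, a minimal standalone sketch of the filename check the new guard performs, assuming dbname "condition" and suffix "ndjson" (the values exercised in the new tests below); it is illustrative only, not part of the change itself:

import re

# Hypothetical demo of the guard's pattern: only files named like
# "<dbname>.<batch number>.<suffix>" are considered safe to overwrite/delete.
allowed_pattern = re.compile(r"condition\.[0-9]+\.ndjson")

for name in [
    "condition.000.ndjson", "condition.22.ndjson",   # expected output files -> True
    "condition.ndjson", "condition.000.parquet", "patient.000.ndjson",  # unexpected -> False
]:
    print(name, bool(allowed_pattern.fullmatch(name)))

# Any non-matching filename makes the writer abort with FOLDER_NOT_EMPTY
# instead of recursively deleting the output folder.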
14 changes: 5 additions & 9 deletions tests/etl/base.py
@@ -101,17 +101,17 @@ def setUp(self) -> None:
super().setUp()

client = fhir.FhirClient("http://localhost/", [])
self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.input_dir = os.path.join(self.tmpdir.name, "input")
self.phi_dir = os.path.join(self.tmpdir.name, "phi")
self.errors_dir = os.path.join(self.tmpdir.name, "errors")
self.tmpdir = self.make_tempdir()
self.input_dir = os.path.join(self.tmpdir, "input")
self.phi_dir = os.path.join(self.tmpdir, "phi")
self.errors_dir = os.path.join(self.tmpdir, "errors")
os.makedirs(self.input_dir)
os.makedirs(self.phi_dir)

self.job_config = JobConfig(
self.input_dir,
self.input_dir,
self.tmpdir.name,
self.tmpdir,
self.phi_dir,
"ndjson",
"ndjson",
@@ -144,10 +144,6 @@ def make_formatter(dbname: str, group_field: str = None, resource_type: str = No
# Keeps consistent IDs
shutil.copy(os.path.join(self.datadir, "simple/codebook.json"), self.phi_dir)

def tearDown(self) -> None:
super().tearDown()
self.tmpdir = None

def make_json(self, filename, resource_id, **kwargs):
common.write_json(
os.path.join(self.input_dir, f"{filename}.ndjson"), {"resourceType": "Test", **kwargs, "id": resource_id}
Empty file added tests/formats/__init__.py
File renamed without changes.
72 changes: 72 additions & 0 deletions tests/formats/test_ndjson.py
@@ -0,0 +1,72 @@
"""Tests for ndjson output format support"""

import os

import ddt

from cumulus_etl import formats, store
from cumulus_etl.formats.ndjson import NdjsonFormat
from tests import utils


@ddt.ddt
class TestNdjsonFormat(utils.AsyncTestCase):
"""
Test case for the ndjson format writer.

i.e. tests for ndjson.py
"""

def setUp(self):
super().setUp()
self.output_tempdir = self.make_tempdir()
self.root = store.Root(self.output_tempdir)
NdjsonFormat.initialize_class(self.root)

@staticmethod
def df(**kwargs) -> list[dict]:
"""
Creates a dummy Table with ids & values equal to each kwarg provided.
"""
return [{"id": k, "value": v} for k, v in kwargs.items()]

def store(
self,
rows: list[dict],
batch_index: int = 10,
) -> bool:
"""
Writes a single batch of data to the output dir.

:param rows: the data to insert
:param batch_index: which batch number this is, defaulting to 10 to avoid triggering any first/last batch logic
"""
ndjson = NdjsonFormat(self.root, "condition")
batch = formats.Batch(rows, index=batch_index)
return ndjson.write_records(batch)

@ddt.data(
(None, True),
([], True),
(["condition.1234.ndjson", "condition.22.ndjson"], True),
(["condition.ndjson"], False),
(["condition.000.parquet"], False),
(["patient.000.ndjson"], False),
)
@ddt.unpack
def test_handles_existing_files(self, files: None | list[str], is_ok: bool):
"""Verify that we bail out if any weird files already exist in the output"""
dbpath = self.root.joinpath("condition")
if files is not None:
os.makedirs(dbpath)
for file in files:
with open(f"{dbpath}/{file}", "w", encoding="utf8") as f:
f.write('{"id": "A"}')

if is_ok:
self.store([{"id": "B"}], batch_index=0)
self.assertEqual(["condition.000.ndjson"], os.listdir(dbpath))
else:
with self.assertRaises(SystemExit):
self.store([{"id": "B"}])
self.assertEqual(files or [], os.listdir(dbpath))
10 changes: 5 additions & 5 deletions tests/test_bulk_export.py
@@ -25,11 +25,11 @@ class TestBulkExporter(AsyncTestCase):

def setUp(self):
super().setUp()
self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.tmpdir = self.make_tempdir()
self.server = mock.AsyncMock()

def make_exporter(self, **kwargs) -> BulkExporter:
return BulkExporter(self.server, ["Condition", "Patient"], "https://localhost/", self.tmpdir.name, **kwargs)
return BulkExporter(self.server, ["Condition", "Patient"], "https://localhost/", self.tmpdir, **kwargs)

async def export(self, **kwargs) -> BulkExporter:
exporter = self.make_exporter(**kwargs)
@@ -79,9 +79,9 @@ async def test_happy_path(self):
self.server.request.call_args_list,
)

self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir.name}/Condition.000.ndjson"))
self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir.name}/Condition.001.ndjson"))
self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir.name}/Patient.000.ndjson"))
self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir}/Condition.000.ndjson"))
self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir}/Condition.001.ndjson"))
self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir}/Patient.000.ndjson"))

async def test_since_until(self):
"""Verify that we send since & until parameters correctly to the server"""
7 changes: 7 additions & 0 deletions tests/utils.py
@@ -7,6 +7,7 @@
import inspect
import json
import os
import tempfile
import time
import tracemalloc
import unittest
@@ -46,6 +47,12 @@ def setUp(self):
# Make it easy to grab test data, regardless of where the test is
self.datadir = os.path.join(os.path.dirname(__file__), "data")

def make_tempdir(self) -> str:
"""Creates a temporary dir that will be automatically cleaned up"""
tempdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.addCleanup(tempdir.cleanup)
return tempdir.name

def patch(self, *args, **kwargs) -> mock.Mock:
"""Syntactic sugar to ease making a mock over a test's lifecycle, without decorators"""
patcher = mock.patch(*args, **kwargs)
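A minimal sketch of how a test might use the new helper; the test class and test method here are hypothetical, and only make_tempdir (provided by utils.AsyncTestCase, as added above) is from this change:

import os

from tests import utils


class ExampleTempdirTest(utils.AsyncTestCase):
    """Hypothetical test demonstrating the make_tempdir pattern."""

    def test_writes_into_tempdir(self):
        workdir = self.make_tempdir()  # plain path string, cleaned up automatically via addCleanup
        path = os.path.join(workdir, "example.ndjson")
        with open(path, "w", encoding="utf8") as f:
            f.write('{"id": "A"}\n')
        self.assertTrue(os.path.exists(path))

Returning a plain string (rather than the TemporaryDirectory object) is what lets the test classes above drop the explicit tearDown and the .name attribute accesses.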