diff --git a/cumulus_etl/fhir/fhir_auth.py b/cumulus_etl/fhir/fhir_auth.py new file mode 100644 index 00000000..8dd5110e --- /dev/null +++ b/cumulus_etl/fhir/fhir_auth.py @@ -0,0 +1,244 @@ +"""Code for the various ways to authenticate against a FHIR server""" + +import base64 +import sys +import time +import urllib.parse +import uuid +from json import JSONDecodeError +from collections.abc import Iterable + +import httpx +from jwcrypto import jwk, jwt + +from cumulus_etl import errors + + +def urljoin(base: str, path: str) -> str: + """Basically just urllib.parse.urljoin, but with some extra error checking""" + path_is_absolute = bool(urllib.parse.urlparse(path).netloc) + if path_is_absolute: + return path + + if not base: + print("You must provide a base FHIR server URL with --fhir-url", file=sys.stderr) + raise SystemExit(errors.FHIR_URL_MISSING) + return urllib.parse.urljoin(base, path) + + +class Auth: + """Abstracted authentication for a FHIR server. By default, does nothing.""" + + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: + """Authorize (or re-authorize) against the server""" + del session + + if reauthorize: + # Abort because we clearly need authentication tokens, but have not been given any parameters for them. + print( + "You must provide some authentication parameters (like --smart-client-id) to connect to a server.", + file=sys.stderr, + ) + raise SystemExit(errors.SMART_CREDENTIALS_MISSING) + + def sign_headers(self, headers: dict) -> dict: + """Add signature token to request headers""" + return headers + + +class JwksAuth(Auth): + """Authentication with a JWK Set (typical backend service profile)""" + + def __init__(self, server_root: str, client_id: str, jwks: dict, resources: Iterable[str]): + super().__init__() + self._server_root = server_root + self._client_id = client_id + self._jwks = jwks + self._resources = list(resources) + self._token_endpoint = None + self._access_token = None + + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: + """ + Authenticates against a SMART FHIR server using the Backend Services profile. + + See https://hl7.org/fhir/smart-app-launch/backend-services.html for details. + """ + if self._token_endpoint is None: # grab URL if we haven't before + self._token_endpoint = await self._get_token_endpoint(session) + + auth_params = { + "grant_type": "client_credentials", + "scope": " ".join([f"system/{resource}.read" for resource in self._resources]), + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": self._make_signed_jwt(), + } + + try: + response = await session.post(self._token_endpoint, data=auth_params) + response.raise_for_status() + self._access_token = response.json().get("access_token") + except httpx.HTTPStatusError as exc: + try: + response_json = exc.response.json() + except JSONDecodeError: + response_json = {} + message = response_json.get("error_description") # standard oauth2 error field + if not message and "error_uri" in response_json: + # Another standard oauth2 error field, which Cerner usually gives back, and it does have helpful info + message = f'visit "{response_json.get("error_uri")}" for more details' + if not message: + message = str(exc) + + errors.fatal(f"Could not authenticate with the FHIR server: {message}", errors.FHIR_AUTH_FAILED) + + def sign_headers(self, headers: dict) -> dict: + """Add signature token to request headers""" + headers["Authorization"] = f"Bearer {self._access_token}" + return headers + + async def _get_token_endpoint(self, session: httpx.AsyncClient) -> str: + """ + Returns the oauth2 token endpoint for a SMART FHIR server. + + See https://hl7.org/fhir/smart-app-launch/client-confidential-asymmetric.html for details. + + If the server does not support the client-confidential-asymmetric protocol, an exception will be raised. + + :returns: URL for the server's oauth2 token endpoint + """ + response = await session.get( + urljoin(self._server_root, ".well-known/smart-configuration"), + headers={ + "Accept": "application/json", + }, + timeout=300, # five minutes + ) + response.raise_for_status() + + # Validate that the server can talk the client-confidential-asymmetric protocol with us. + # Some servers (like Cerner) don't advertise their support with the 'client-confidential-asymmetric' + # capability keyword, so let's not bother checking for it. But we can confirm that the pieces are there. + config = response.json() + if "private_key_jwt" not in config.get("token_endpoint_auth_methods_supported", []) or not config.get( + "token_endpoint" + ): + errors.fatal( + f"Server {self._server_root} does not support the client-confidential-asymmetric protocol", + errors.FHIR_AUTH_FAILED, + ) + + return config["token_endpoint"] + + def _make_signed_jwt(self) -> str: + """ + Creates a signed JWT for use in the client-confidential-asymmetric protocol. + + See https://hl7.org/fhir/smart-app-launch/client-confidential-asymmetric.html for details. + + :returns: a signed JWT string, ready for authentication with the FHIR server + """ + # Find a usable singing JWK from JWKS + for key in self._jwks.get("keys", []): + if key.get("alg") in ["ES384", "RS384"] and "sign" in key.get("key_ops", []) and key.get("kid"): + break + else: # no valid private JWK found + raise errors.FatalError("No private ES384 or RS384 key found in the provided JWKS file.") + + # Now generate a signed JWT based off the given JWK + header = { + "alg": key["alg"], + "kid": key["kid"], + "typ": "JWT", + } + claims = { + "iss": self._client_id, + "sub": self._client_id, + "aud": self._token_endpoint, + "exp": int(time.time()) + 299, # expires inside five minutes + "jti": str(uuid.uuid4()), + } + token = jwt.JWT(header=header, claims=claims) + token.make_signed_token(key=jwk.JWK(**key)) + return token.serialize() + + +class BasicAuth(Auth): + """Authentication with basic user/password""" + + def __init__(self, user: str, password: str): + super().__init__() + # Assume utf8 is acceptable -- we should in theory also run these through Unicode normalization, in case they + # have interesting Unicode characters. But we can always add that in the future. + combo_bytes = f"{user}:{password}".encode("utf8") + self._basic_token = base64.standard_b64encode(combo_bytes).decode("ascii") + + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: + pass + + def sign_headers(self, headers: dict) -> dict: + headers["Authorization"] = f"Basic {self._basic_token}" + return headers + + +class BearerAuth(Auth): + """Authentication with a static bearer token""" + + def __init__(self, bearer_token: str): + super().__init__() + self._bearer_token = bearer_token + + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: + pass + + def sign_headers(self, headers: dict) -> dict: + headers["Authorization"] = f"Bearer {self._bearer_token}" + return headers + + +def create_auth( + server_root: str | None, + resources: Iterable[str], + basic_user: str | None, + basic_password: str | None, + bearer_token: str | None, + smart_client_id: str | None, + smart_jwks: dict | None, +) -> Auth: + """Determine which auth method to use based on user provided arguments""" + valid_smart_jwks = smart_jwks is not None # compared to a falsy (but technically usable) empty dict for example + + # Check if the user tried to specify multiple types of auth, and help them out + has_basic_args = bool(basic_user or basic_password) + has_bearer_args = bool(bearer_token) + has_smart_args = bool(valid_smart_jwks) + total_auth_types = has_basic_args + has_bearer_args + has_smart_args + if total_auth_types > 1: + print( + "Multiple authentication methods have been specified. Double check your arguments to Cumulus ETL.", + file=sys.stderr, + ) + raise SystemExit(errors.ARGS_CONFLICT) + + if basic_user and basic_password: + return BasicAuth(basic_user, basic_password) + elif basic_user or basic_password: + print( + "You must provide both --basic-user and --basic-password to connect to a Basic auth server.", + file=sys.stderr, + ) + raise SystemExit(errors.BASIC_CREDENTIALS_MISSING) + + if bearer_token: + return BearerAuth(bearer_token) + + if smart_client_id and valid_smart_jwks: + return JwksAuth(server_root, smart_client_id, smart_jwks, resources) + elif smart_client_id or valid_smart_jwks: + print( + "You must provide both --smart-client-id and --smart-jwks to connect to a SMART FHIR server.", + file=sys.stderr, + ) + raise SystemExit(errors.SMART_CREDENTIALS_MISSING) + + return Auth() diff --git a/cumulus_etl/fhir/fhir_client.py b/cumulus_etl/fhir/fhir_client.py index c1dd277d..072bd79f 100644 --- a/cumulus_etl/fhir/fhir_client.py +++ b/cumulus_etl/fhir/fhir_client.py @@ -1,201 +1,22 @@ """HTTP client that talk to a FHIR server""" import argparse -import base64 +import enum import re import sys -import time -import urllib.parse -import uuid from json import JSONDecodeError from collections.abc import Iterable import httpx -from jwcrypto import jwk, jwt from cumulus_etl import common, errors, store +from cumulus_etl.fhir import fhir_auth -def _urljoin(base: str, path: str) -> str: - """Basically just urllib.parse.urljoin, but with some extra error checking""" - path_is_absolute = bool(urllib.parse.urlparse(path).netloc) - if path_is_absolute: - return path - - if not base: - print("You must provide a base FHIR server URL with --fhir-url", file=sys.stderr) - raise SystemExit(errors.FHIR_URL_MISSING) - return urllib.parse.urljoin(base, path) - - -class Auth: - """Abstracted authentication for a FHIR server. By default, does nothing.""" - - async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: - """Authorize (or re-authorize) against the server""" - del session - - if reauthorize: - # Abort because we clearly need authentication tokens, but have not been given any parameters for them. - print( - "You must provide some authentication parameters (like --smart-client-id) to connect to a server.", - file=sys.stderr, - ) - raise SystemExit(errors.SMART_CREDENTIALS_MISSING) - - def sign_headers(self, headers: dict) -> dict: - """Add signature token to request headers""" - return headers - - -class JwksAuth(Auth): - """Authentication with a JWK Set (typical backend service profile)""" - - def __init__(self, server_root: str, client_id: str, jwks: dict, resources: Iterable[str]): - super().__init__() - self._server_root = server_root - self._client_id = client_id - self._jwks = jwks - self._resources = list(resources) - self._token_endpoint = None - self._access_token = None - - async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: - """ - Authenticates against a SMART FHIR server using the Backend Services profile. - - See https://hl7.org/fhir/smart-app-launch/backend-services.html for details. - """ - if self._token_endpoint is None: # grab URL if we haven't before - self._token_endpoint = await self._get_token_endpoint(session) - - auth_params = { - "grant_type": "client_credentials", - "scope": " ".join([f"system/{resource}.read" for resource in self._resources]), - "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_assertion": self._make_signed_jwt(), - } - - try: - response = await session.post(self._token_endpoint, data=auth_params) - response.raise_for_status() - self._access_token = response.json().get("access_token") - except httpx.HTTPStatusError as exc: - try: - response_json = exc.response.json() - except JSONDecodeError: - response_json = {} - message = response_json.get("error_description") # standard oauth2 error field - if not message and "error_uri" in response_json: - # Another standard oauth2 error field, which Cerner usually gives back, and it does have helpful info - message = f'visit "{response_json.get("error_uri")}" for more details' - if not message: - message = str(exc) - - errors.fatal(f"Could not authenticate with the FHIR server: {message}", errors.FHIR_AUTH_FAILED) - - def sign_headers(self, headers: dict) -> dict: - """Add signature token to request headers""" - headers["Authorization"] = f"Bearer {self._access_token}" - return headers - - async def _get_token_endpoint(self, session: httpx.AsyncClient) -> str: - """ - Returns the oauth2 token endpoint for a SMART FHIR server. - - See https://hl7.org/fhir/smart-app-launch/client-confidential-asymmetric.html for details. - - If the server does not support the client-confidential-asymmetric protocol, an exception will be raised. - - :returns: URL for the server's oauth2 token endpoint - """ - response = await session.get( - _urljoin(self._server_root, ".well-known/smart-configuration"), - headers={ - "Accept": "application/json", - }, - timeout=300, # five minutes - ) - response.raise_for_status() - - # Validate that the server can talk the client-confidential-asymmetric protocol with us. - # Some servers (like Cerner) don't advertise their support with the 'client-confidential-asymmetric' - # capability keyword, so let's not bother checking for it. But we can confirm that the pieces are there. - config = response.json() - if "private_key_jwt" not in config.get("token_endpoint_auth_methods_supported", []) or not config.get( - "token_endpoint" - ): - errors.fatal( - f"Server {self._server_root} does not support the client-confidential-asymmetric protocol", - errors.FHIR_AUTH_FAILED, - ) - - return config["token_endpoint"] - - def _make_signed_jwt(self) -> str: - """ - Creates a signed JWT for use in the client-confidential-asymmetric protocol. - - See https://hl7.org/fhir/smart-app-launch/client-confidential-asymmetric.html for details. - - :returns: a signed JWT string, ready for authentication with the FHIR server - """ - # Find a usable singing JWK from JWKS - for key in self._jwks.get("keys", []): - if key.get("alg") in ["ES384", "RS384"] and "sign" in key.get("key_ops", []) and key.get("kid"): - break - else: # no valid private JWK found - raise errors.FatalError("No private ES384 or RS384 key found in the provided JWKS file.") - - # Now generate a signed JWT based off the given JWK - header = { - "alg": key["alg"], - "kid": key["kid"], - "typ": "JWT", - } - claims = { - "iss": self._client_id, - "sub": self._client_id, - "aud": self._token_endpoint, - "exp": int(time.time()) + 299, # expires inside five minutes - "jti": str(uuid.uuid4()), - } - token = jwt.JWT(header=header, claims=claims) - token.make_signed_token(key=jwk.JWK(**key)) - return token.serialize() - - -class BasicAuth(Auth): - """Authentication with basic user/password""" - - def __init__(self, user: str, password: str): - super().__init__() - # Assume utf8 is acceptable -- we should in theory also run these through Unicode normalization, in case they - # have interesting Unicode characters. But we can always add that in the future. - combo_bytes = f"{user}:{password}".encode("utf8") - self._basic_token = base64.standard_b64encode(combo_bytes).decode("ascii") - - async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: - pass - - def sign_headers(self, headers: dict) -> dict: - headers["Authorization"] = f"Basic {self._basic_token}" - return headers - - -class BearerAuth(Auth): - """Authentication with a static bearer token""" - - def __init__(self, bearer_token: str): - super().__init__() - self._bearer_token = bearer_token - - async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: - pass - - def sign_headers(self, headers: dict) -> dict: - headers["Authorization"] = f"Bearer {self._bearer_token}" - return headers +class ServerType(enum.Enum): + UNKNOWN = enum.auto() + CERNER = enum.auto() + EPIC = enum.auto() class FhirClient: @@ -233,13 +54,22 @@ def __init__( self._server_root = url # all requests are relative to this URL if self._server_root and not self._server_root.endswith("/"): self._server_root += "/" # This will ensure the last segment does not get chopped off by urljoin - self._auth = self._make_auth(resources, basic_user, basic_password, bearer_token, smart_client_id, smart_jwks) + + self._client_id = smart_client_id + self._server_type = ServerType.UNKNOWN + self._auth = fhir_auth.create_auth( + self._server_root, resources, basic_user, basic_password, bearer_token, smart_client_id, smart_jwks + ) self._session: httpx.AsyncClient | None = None async def __aenter__(self): # Limit the number of connections open at once, because EHRs tend to be very busy. limits = httpx.Limits(max_connections=5) - self._session = httpx.AsyncClient(limits=limits, timeout=300) # five minutes to be generous + timeout = 300 # five minutes to be generous + # Follow redirects by default -- some EHRs definitely use them for bulk download files, + # and might use them in other cases, who knows. + self._session = httpx.AsyncClient(limits=limits, timeout=timeout, follow_redirects=True) + await self._read_capabilities() # discover server type, etc await self._auth.authorize(self._session) return self @@ -264,7 +94,7 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo :param stream: whether to stream content in or load it all into memory at once :returns: The response object """ - url = _urljoin(self._server_root, path) + url = fhir_auth.urljoin(self._server_root, path) final_headers = { "Accept": "application/fhir+json", @@ -315,52 +145,38 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo # ################################################################################################################### - def _make_auth( - self, - resources: Iterable[str], - basic_user: str, - basic_password: str, - bearer_token: str, - smart_client_id: str, - smart_jwks: dict, - ) -> Auth: - """Determine which auth method to use based on user provided arguments""" - valid_smart_jwks = smart_jwks is not None # compared to a falsy (but technically usable) empty dict for example - - # Check if the user tried to specify multiple types of auth, and help them out - has_basic_args = bool(basic_user or basic_password) - has_bearer_args = bool(bearer_token) - has_smart_args = bool(smart_client_id or valid_smart_jwks) - total_auth_types = has_basic_args + has_bearer_args + has_smart_args - if total_auth_types > 1: - print( - "Multiple authentication methods have been specified. Double check your arguments to Cumulus ETL.", - file=sys.stderr, - ) - raise SystemExit(errors.ARGS_CONFLICT) - - if basic_user and basic_password: - return BasicAuth(basic_user, basic_password) - elif basic_user or basic_password: - print( - "You must provide both --basic-user and --basic-password to connect to a Basic auth server.", - file=sys.stderr, - ) - raise SystemExit(errors.BASIC_CREDENTIALS_MISSING) + async def _read_capabilities(self) -> None: + """ + Reads the server's CapabilityStatement and sets any properties as a result (like server/vendor type). - if bearer_token: - return BearerAuth(bearer_token) + This is expected to be called extremely early, right as the http session is opened. + """ + if not self._server_root: + return - if smart_client_id and valid_smart_jwks: - return JwksAuth(self._server_root, smart_client_id, smart_jwks, resources) - elif smart_client_id or valid_smart_jwks: - print( - "You must provide both --smart-client-id and --smart-jwks to connect to a SMART FHIR server.", - file=sys.stderr, + try: + response = await self._session.get( + fhir_auth.urljoin(self._server_root, "metadata"), + headers={ + "Accept": "application/json", + "Accept-Charset": "UTF-8", + }, ) - raise SystemExit(errors.SMART_CREDENTIALS_MISSING) + response.raise_for_status() + except httpx.HTTPError: + return # That's fine - just skip this optional metadata - return Auth() + try: + capabilities = response.json() + except JSONDecodeError: + return + + if capabilities.get("publisher") == "Cerner": + # Example: https://fhir-ehr-code.cerner.com/r4/ec2458f2-1e24-41c8-b71b-0e701af7583d/metadata?_format=json + self._server_type = ServerType.CERNER + elif capabilities.get("software", {}).get("name") == "Epic": + # Example: https://fhir.epic.com/interconnect-fhir-oauth/api/FHIR/R4/metadata?_format=json + self._server_type = ServerType.EPIC async def _request_with_signed_headers(self, method: str, url: str, headers: dict, **kwargs) -> httpx.Response: """ @@ -374,11 +190,15 @@ async def _request_with_signed_headers(self, method: str, url: str, headers: dic if not self._session: raise RuntimeError("FhirClient must be used as a context manager") + # Epic wants to see the Epic-Client-ID header, especially for non-OAuth flows. + # (but I've heard reports of also wanting it in OAuth flows too) + # See https://fhir.epic.com/Documentation?docId=oauth2§ion=NonOauth_Epic-Client-ID-Header + if self._server_type == ServerType.EPIC and self._client_id: + headers["Epic-Client-ID"] = self._client_id + headers = self._auth.sign_headers(headers) request = self._session.build_request(method, url, headers=headers) - # Follow redirects by default -- some EHRs definitely use them for bulk download files, - # and might use them in other cases, who knows. - return await self._session.send(request, follow_redirects=True, **kwargs) + return await self._session.send(request, **kwargs) def create_fhir_client_for_cli( diff --git a/tests/covid_symptom/test_nlp_results.py b/tests/covid_symptom/test_nlp_results.py index 7e6d298d..bff51e9b 100644 --- a/tests/covid_symptom/test_nlp_results.py +++ b/tests/covid_symptom/test_nlp_results.py @@ -118,13 +118,13 @@ async def test_bad_doc_status_is_skipped_for_covid_symptoms(self, status: dict, ([("file-cough", "text/nope")], None), # ignores unsupported mimetypes ) @ddt.unpack - @respx.mock - async def test_note_urls_downloaded(self, attachments, expected_text): + @respx.mock(assert_all_mocked=False, assert_all_called=False) + async def test_note_urls_downloaded(self, attachments, expected_text, respx_mock): """Verify that we download any attachments with URLs""" # We return three words due to how our cTAKES mock works. It wants 3 words -- fever word is in middle. - respx.get("http://localhost/file-cough").respond(text="has cough bad") - respx.get("http://localhost/file-fever").respond(text="has fever bad") - respx.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES + respx_mock.get("http://localhost/file-cough").respond(text="has cough bad") + respx_mock.get("http://localhost/file-fever").respond(text="has fever bad") + respx_mock.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES docref0 = i2b2_mock_data.documentreference() docref0["content"] = [{"attachment": {"url": a[0], "contentType": a[1]}} for a in attachments] diff --git a/tests/fhir/test_fhir_client.py b/tests/fhir/test_fhir_client.py index 974497a7..cd642723 100644 --- a/tests/fhir/test_fhir_client.py +++ b/tests/fhir/test_fhir_client.py @@ -13,7 +13,7 @@ @ddt.ddt -@mock.patch("cumulus_etl.fhir.fhir_client.uuid.uuid4", new=lambda: "1234") +@mock.patch("cumulus_etl.fhir.fhir_auth.uuid.uuid4", new=lambda: "1234") class TestFhirClient(AsyncTestCase): """ Test case for FHIR client oauth2 / request support. @@ -68,6 +68,9 @@ def setUp(self): json=self.smart_configuration, ) + # empty capabilities (no vendor quirks) by default + self.respx_mock.get(f"{self.server_url}/metadata").respond(json={}) + self.respx_mock.post( self.token_url, name="token", @@ -303,3 +306,44 @@ def test_added_binary_scope(self, resources_in, expected_resources_out, mock_cli ) fhir.create_fhir_client_for_cli(args, store.Root("/tmp"), resources_in) self.assertEqual(mock_client.call_args[0][1], expected_resources_out) + + +@ddt.ddt +@mock.patch("cumulus_etl.fhir.fhir_auth.uuid.uuid4", new=lambda: "1234") +class TestFhirClientEpicQuirks(AsyncTestCase): + """Test case for FHIR client handling of Epic-specific vendor quirks.""" + + def setUp(self): + super().setUp() + self.server_url = "http://localhost" + + self.respx_mock = respx.mock(assert_all_called=False) + self.addCleanup(self.respx_mock.stop) + self.respx_mock.start() + + def mock_as_server_type(self, server_type: str | None): + response_json = {} + if server_type == "epic": + response_json = {"software": {"name": "Epic"}} + + self.respx_mock.get(f"{self.server_url}/metadata").respond(json=response_json) + + @ddt.data( + ("epic", "present"), + (None, "missing"), + ) + @ddt.unpack + async def test_client_id_in_header(self, server_type, expected_text): + # Mock with header + self.respx_mock.get(f"{self.server_url}/file", headers={"Epic-Client-ID": "my-id"},).respond( + text="present", + ) + # And without + self.respx_mock.get(f"{self.server_url}/file",).respond( + text="missing", + ) + + self.mock_as_server_type(server_type) + async with fhir.FhirClient(self.server_url, [], bearer_token="foo", smart_client_id="my-id") as server: + response = await server.request("GET", "file") + self.assertEqual(expected_text, response.text) diff --git a/tests/test_bulk_export.py b/tests/test_bulk_export.py index 248541e3..2001118e 100644 --- a/tests/test_bulk_export.py +++ b/tests/test_bulk_export.py @@ -258,6 +258,11 @@ def setUp(self) -> None: self.jwks_path = self.jwks_file.name def set_up_requests(self, respx_mock): + # /metadata + respx_mock.get( + f"{self.root.path}/metadata", + ).respond(json={}) + # /.well-known/smart-configuration respx_mock.get( f"{self.root.path}/.well-known/smart-configuration", diff --git a/tests/test_chart_cli.py b/tests/test_chart_cli.py index 8d5de660..acc459d3 100644 --- a/tests/test_chart_cli.py +++ b/tests/test_chart_cli.py @@ -123,7 +123,7 @@ def make_docref(doc_id: str, text: str = None, content: list[dict] = None) -> di } @staticmethod - def mock_search_url(patient: str, doc_ids: Iterable[str]) -> None: + def mock_search_url(respx_mock: respx.MockRouter, patient: str, doc_ids: Iterable[str]) -> None: bundle = { "resourceType": "Bundle", "entry": [ @@ -134,12 +134,12 @@ def mock_search_url(patient: str, doc_ids: Iterable[str]) -> None: ], } - respx.get(f"https://localhost/DocumentReference?patient={patient}&_elements=content").respond(json=bundle) + respx_mock.get(f"https://localhost/DocumentReference?patient={patient}&_elements=content").respond(json=bundle) @staticmethod - def mock_read_url(doc_id: str, code: int = 200, **kwargs) -> None: + def mock_read_url(respx_mock: respx.MockRouter, doc_id: str, code: int = 200, **kwargs) -> None: docref = TestChartReview.make_docref(doc_id, **kwargs) - respx.get(f"https://localhost/DocumentReference/{doc_id}").respond(status_code=code, json=docref) + respx_mock.get(f"https://localhost/DocumentReference/{doc_id}").respond(status_code=code, json=docref) @staticmethod def write_anon_docrefs(path: str, ids: list[tuple[str, str]]) -> None: @@ -169,11 +169,11 @@ async def test_real_and_fake_docrefs_conflict(self): await self.run_chart_review(anon_docrefs="foo", docrefs="bar") self.assertEqual(errors.ARGS_CONFLICT, cm.exception.code) - @respx.mock - async def test_gather_anon_docrefs_from_server(self): - self.mock_search_url("P1", ["NotMe", "D1", "NotThis", "D3"]) - self.mock_search_url("P2", ["D2"]) - respx.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES + @respx.mock(assert_all_mocked=False) + async def test_gather_anon_docrefs_from_server(self, respx_mock): + self.mock_search_url(respx_mock, "P1", ["NotMe", "D1", "NotThis", "D3"]) + self.mock_search_url(respx_mock, "P2", ["D2"]) + respx_mock.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES with tempfile.NamedTemporaryFile() as file: self.write_anon_docrefs( @@ -190,13 +190,13 @@ async def test_gather_anon_docrefs_from_server(self): self.assertEqual({"D1", "D2", "D3"}, self.get_exported_ids()) self.assertEqual({"D1", "D2", "D3"}, self.get_pushed_ids()) - @respx.mock - async def test_gather_real_docrefs_from_server(self): - self.mock_read_url("D1") - self.mock_read_url("D2") - self.mock_read_url("D3") - self.mock_read_url("unknown-doc", code=404) - respx.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES + @respx.mock(assert_all_mocked=False) + async def test_gather_real_docrefs_from_server(self, respx_mock): + self.mock_read_url(respx_mock, "D1") + self.mock_read_url(respx_mock, "D2") + self.mock_read_url(respx_mock, "D3") + self.mock_read_url(respx_mock, "unknown-doc", code=404) + respx_mock.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES with tempfile.NamedTemporaryFile() as file: self.write_real_docrefs(file.name, ["D1", "D2", "D3", "unknown-doc"])