diff --git a/process_report/processors/add_institution_processor.py b/process_report/processors/add_institution_processor.py index 1440c49..d5896a5 100644 --- a/process_report/processors/add_institution_processor.py +++ b/process_report/processors/add_institution_processor.py @@ -14,28 +14,6 @@ @dataclass class AddInstitutionProcessor(processor.Processor): - @staticmethod - def _get_institute_mapping(institute_list: list): - institute_map = dict() - for institute_info in institute_list: - for domain in institute_info["domains"]: - institute_map[domain] = institute_info["display_name"] - - return institute_map - - @staticmethod - def _get_institution_from_pi(institute_map, pi_uname): - institution_domain = pi_uname.split("@")[-1] - for i in range(institution_domain.count(".") + 1): - if institution_name := institute_map.get(institution_domain, ""): - break - institution_domain = institution_domain[institution_domain.find(".") + 1 :] - - if institution_name == "": - logger.warning(f"PI name {pi_uname} does not match any institution!") - - return institution_name - def _add_institution(self): """Determine every PI's institution name, logging any PI whose institution cannot be determined This is performed by `get_institution_from_pi()`, which tries to match the PI's username to @@ -49,7 +27,7 @@ def _add_institution(self): The list of mappings are defined in `institute_map.json`. """ institute_list = util.load_institute_list() - institute_map = self._get_institute_mapping(institute_list) + institute_map = util.get_institute_mapping(institute_list) self.data = self.data.astype({invoice.INSTITUTION_FIELD: "str"}) for i, row in self.data.iterrows(): pi_name = row[invoice.PI_FIELD] @@ -58,7 +36,7 @@ def _add_institution(self): else: self.data.at[ i, invoice.INSTITUTION_FIELD - ] = self._get_institution_from_pi(institute_map, pi_name) + ] = util.get_institution_from_pi(institute_map, pi_name) def _process(self): self._add_institution() diff --git a/process_report/tests/unit/processors/test_add_institution_processor.py b/process_report/tests/unit/processors/test_add_institution_processor.py index 775cb23..a634c74 100644 --- a/process_report/tests/unit/processors/test_add_institution_processor.py +++ b/process_report/tests/unit/processors/test_add_institution_processor.py @@ -1,9 +1,10 @@ from unittest import TestCase +from process_report import util from process_report.tests import util as test_utils -class TestAddInstituteProcessor(TestCase): +class TestAddInstitute(TestCase): def test_get_pi_institution(self): institute_map = { "harvard.edu": "Harvard University", @@ -36,6 +37,5 @@ def test_get_pi_institution(self): for pi_email, answer in answers.items(): self.assertEqual( - add_institute_proc._get_institution_from_pi(institute_map, pi_email), - answer, + util.get_institution_from_pi(institute_map, pi_email), answer )