diff --git a/process_report/processors/add_institution_processor.py b/process_report/processors/add_institution_processor.py index ecd2530..e86ec91 100644 --- a/process_report/processors/add_institution_processor.py +++ b/process_report/processors/add_institution_processor.py @@ -9,28 +9,6 @@ @dataclass class AddInstitutionProcessor(processor.Processor): - @staticmethod - def _get_institute_mapping(institute_list: list): - institute_map = dict() - for institute_info in institute_list: - for domain in institute_info["domains"]: - institute_map[domain] = institute_info["display_name"] - - return institute_map - - @staticmethod - def _get_institution_from_pi(institute_map, pi_uname): - institution_domain = pi_uname.split("@")[-1] - for i in range(institution_domain.count(".") + 1): - if institution_name := institute_map.get(institution_domain, ""): - break - institution_domain = institution_domain[institution_domain.find(".") + 1 :] - - if institution_name == "": - print(f"Warning: PI name {pi_uname} does not match any institution!") - - return institution_name - def _add_institution(self): """Determine every PI's institution name, logging any PI whose institution cannot be determined This is performed by `get_institution_from_pi()`, which tries to match the PI's username to @@ -44,7 +22,7 @@ def _add_institution(self): The list of mappings are defined in `institute_map.json`. """ institute_list = util.load_institute_list() - institute_map = self._get_institute_mapping(institute_list) + institute_map = util.get_institute_mapping(institute_list) self.data = self.data.astype({invoice.INSTITUTION_FIELD: "str"}) for i, row in self.data.iterrows(): pi_name = row[invoice.PI_FIELD] @@ -53,7 +31,7 @@ def _add_institution(self): else: self.data.at[ i, invoice.INSTITUTION_FIELD - ] = self._get_institution_from_pi(institute_map, pi_name) + ] = util.get_institution_from_pi(institute_map, pi_name) def _process(self): self._add_institution() diff --git a/process_report/tests/unit_tests.py b/process_report/tests/unit_tests.py index c2a7e7b..0743848 100644 --- a/process_report/tests/unit_tests.py +++ b/process_report/tests/unit_tests.py @@ -157,7 +157,7 @@ def test_export_pi(self, mock_filter_cols): self.assertNotIn("ProjectC", pi_df["Project - Allocation"].tolist()) -class TestAddInstituteProcessor(TestCase): +class TestAddInstitute(TestCase): def test_get_pi_institution(self): institute_map = { "harvard.edu": "Harvard University", @@ -186,12 +186,9 @@ def test_get_pi_institution(self): "g@bidmc.harvard.edu": "Beth Israel Deaconess Medical Center", } - add_institute_proc = test_utils.new_add_institution_processor() - for pi_email, answer in answers.items(): self.assertEqual( - add_institute_proc._get_institution_from_pi(institute_map, pi_email), - answer, + util.get_institution_from_pi(institute_map, pi_email), answer )