Skip to content

Commit

Permalink
Allows old PI file to be fetched from S3, updated during processing, …
Browse files Browse the repository at this point in the history
…and uploaded back to S3
  • Loading branch information
QuanMPhm committed Apr 24, 2024
1 parent 6d74b3d commit ec07b44
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
37 changes: 32 additions & 5 deletions process_report/process_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ def load_old_pis(old_pi_file):
return old_pi_dict


def dump_old_pis(old_pi_file, old_pi_dict: dict):
    """Write the PI record to `old_pi_file`, overwriting any existing content.

    Each entry of `old_pi_dict` becomes one `<pi>,<first_month>` line, emitted
    in the dict's iteration order.
    """
    with open(old_pi_file, "w") as out:
        # The generator is consumed inside the `with`, so the file is opened
        # (and truncated) before any entry is formatted.
        out.writelines(
            f"{name},{month}\n" for name, month in old_pi_dict.items()
        )


def is_old_pi(old_pi_dict, pi, invoice_month):
    """Return True if `pi` is recorded with a first month other than `invoice_month`.

    A PI that is absent from `old_pi_dict`, or whose recorded month equals
    `invoice_month`, is reported as not old (False).
    """
    # NOTE(review): when re-processing an old invoice, a PI who was new *back
    # then* is flagged as old here. That should only happen if a PI recorded
    # for a later month also appears in an earlier invoice -- confirm.
    already_recorded = pi in old_pi_dict
    return bool(already_recorded and old_pi_dict[pi] != invoice_month)
Expand Down Expand Up @@ -101,7 +108,7 @@ def main():
parser.add_argument(
"--upload-to-s3",
action="store_true",
help="If set, uploads all processed invoices to S3",
help="If set, uploads all processed invoices and old PI file to S3",
)
parser.add_argument(
"--invoice-month",
Expand Down Expand Up @@ -163,16 +170,20 @@ def main():
parser.add_argument(
"--old-pi-file",
required=False,
help="Name of csv file listing previously billed PIs",
help="Name of csv file listing previously billed PIs. If not provided, defaults to fetching from S3",
)
args = parser.parse_args()

invoice_month = args.invoice_month

if args.fetch_from_s3:
csv_files = fetch_S3_invoices(invoice_month)
csv_files = fetch_s3_invoices(invoice_month)
else:
csv_files = args.csv_files
if args.old_pi_file:
old_pi_file = args.old_pi_file
else:
old_pi_file = fetch_s3_old_pi_file()

merged_dataframe = merge_csv(csv_files)

Expand All @@ -196,7 +207,7 @@ def main():

billable_projects = remove_non_billables(merged_dataframe, pi, projects)
billable_projects = validate_pi_names(billable_projects)
credited_projects = apply_credits_new_pi(billable_projects, args.old_pi_file)
credited_projects = apply_credits_new_pi(billable_projects, old_pi_file)

export_billables(credited_projects, args.output_file)
export_pi_billables(credited_projects, args.output_folder, invoice_month)
Expand All @@ -217,9 +228,10 @@ def main():
invoice_list.append(os.path.join(args.output_folder, pi_invoice))

upload_to_s3(invoice_list, invoice_month)
upload_to_s3_old_pi_file(old_pi_file)


def fetch_S3_invoices(invoice_month):
def fetch_s3_invoices(invoice_month):
"""Fetches usage invoices from S3 given invoice month"""
s3_invoice_list = list()
invoice_bucket = get_invoice_bucket()
Expand Down Expand Up @@ -320,6 +332,7 @@ def apply_credits_new_pi(dataframe, old_pi_file):
for i, row in pi_projects.iterrows():
dataframe.at[i, BALANCE_FIELD] = row[COST_FIELD]
else:
old_pi_dict[pi] = invoice_month
remaining_credit = new_pi_credit_amount
for i, row in pi_projects.iterrows():
project_cost = row[COST_FIELD]
Expand All @@ -333,9 +346,23 @@ def apply_credits_new_pi(dataframe, old_pi_file):
if remaining_credit == 0:
break

dump_old_pis(old_pi_file, old_pi_dict)

return dataframe


def fetch_s3_old_pi_file():
    """Download the old-PI file from the invoice bucket and return its local path.

    The object stored under the key "PIs/PI.csv" in the bucket returned by
    `get_invoice_bucket()` is saved as "PI.csv" (a relative path, so it lands
    in the current working directory).
    """
    destination = "PI.csv"
    get_invoice_bucket().download_file("PIs/PI.csv", destination)
    return destination


def upload_to_s3_old_pi_file(old_pi_file):
    """Upload the local file `old_pi_file` to the invoice bucket.

    The file is stored under the key "PIs/PI.csv" in the bucket returned by
    `get_invoice_bucket()` -- the same key `fetch_s3_old_pi_file` reads from.
    """
    get_invoice_bucket().upload_file(old_pi_file, "PIs/PI.csv")


def add_institution(dataframe: pandas.DataFrame):
"""Determine every PI's institution name, logging any PI whose institution cannot be determined
This is performed by `get_institution_from_pi()`, which tries to match the PI's username to
Expand Down
4 changes: 4 additions & 0 deletions process_report/tests/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ def test_apply_credit_0002(self):
self.assertEqual(0, credited_projects.loc[4, "Balance"])
self.assertEqual(800, credited_projects.loc[5, "Balance"])

updated_old_pi_answer = "PI2,2023-09\nPI3,2024-02\nPI4,2024-03\nPI1,2024-03\n"
with open(self.old_pi_file, "r") as f:
self.assertEqual(updated_old_pi_answer, f.read())


class TestValidateBillables(TestCase):
def setUp(self):
Expand Down

0 comments on commit ec07b44

Please sign in to comment.