Skip to content

Commit

Permalink
Merge pull request #4 from Sage-Bionetworks-Challenges/update-evaluation
Browse files Browse the repository at this point in the history
Update expected colnames and metric calculations
  • Loading branch information
vpchung authored Jul 3, 2024
2 parents aa69eca + eda41f7 commit f610c7c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
36 changes: 26 additions & 10 deletions score.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
- ROC curve
- PR curve
"""
from glob import glob
import argparse
import json
import os
from glob import glob

import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score

# Column name -> dtype mappings for the gold standard and prediction files.
GOLDSTANDARD_COLS = {"id": str, "disease": int}
PREDICTION_COLS = {"id": str, "disease_probability": np.float64}
# The gold standard's `disease_probability` is read as str: preprocess()
# removes rows holding the ".M" value and only then casts the column to int.
GOLDSTANDARD_COLS = {"epr_number": str, "disease_probability": str}
PREDICTION_COLS = {"epr_number": str, "disease_probability": np.float64}


def get_args():
Expand All @@ -27,12 +27,21 @@ def get_args():
return parser.parse_args()


def score(gold, pred, id_colname, prob_colname):
    """Calculate AUC-ROC and AUPRC for predictions against the gold standard.

    Args:
        gold: Gold standard dataframe containing ``id_colname`` and the true
            labels in ``prob_colname``.
        pred: Prediction dataframe containing ``id_colname`` and the predicted
            probabilities in ``prob_colname``.
        id_colname: Name of the ID column shared by both dataframes.
        prob_colname: Name of the column holding the labels (in ``gold``) and
            the predicted probabilities (in ``pred``).

    Returns:
        dict with the keys ``"auc_roc"`` and ``"auprc"``.
    """
    needed = [id_colname, prob_colname]
    # Join on the ID so labels and predictions are aligned row by row,
    # whatever order the two files list their IDs in.  Only the two needed
    # columns are joined and the suffixes are explicit, so the lookups below
    # do not depend on pandas' default "_x"/"_y" naming or on extra columns.
    # With a left join, IDs absent from `pred` get NaN as their prediction.
    merged = gold[needed].merge(
        pred[needed],
        how="left",
        on=id_colname,
        suffixes=("_gold", "_pred"),
    )
    y_true = merged[prob_colname + "_gold"]
    y_score = merged[prob_colname + "_pred"]
    return {
        "auc_roc": roc_auc_score(y_true, y_score),
        "auprc": average_precision_score(y_true, y_score),
    }


Expand All @@ -44,10 +53,16 @@ def extract_gs_file(folder):
"Expected exactly one gold standard file in folder. "
f"Got {len(files)}. Exiting."
)

return files[0]


def preprocess(df, colname):
    """Drop ".M" rows from a dataframe and convert a column to int.

    Args:
        df: Dataframe to clean. It is not modified.
        colname: Name of the column to filter on and convert.

    Returns:
        A new dataframe without the rows whose ``colname`` value is ".M",
        and with ``colname`` cast to int.

    Raises:
        ValueError: If a remaining value cannot be converted to int.
    """
    # Copy the filtered selection so the assignment below writes to an
    # independent frame instead of a slice of the caller's dataframe
    # (chained assignment, which triggers pandas' SettingWithCopyWarning).
    df = df[~df[colname].isin([".M"])].copy()
    df[colname] = df[colname].astype(int)
    return df


def main():
"""Main function."""
args = get_args()
Expand All @@ -71,7 +86,8 @@ def main():
usecols=GOLDSTANDARD_COLS,
dtype=GOLDSTANDARD_COLS
)
scores = score(gold, "disease", pred, "disease_probability")
gold = preprocess(gold, "disease_probability")
scores = score(gold, pred, "epr_number", "disease_probability")
status = "SCORED"
errors = ""
except ValueError:
Expand Down
14 changes: 7 additions & 7 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import numpy as np
import pandas as pd

# Column name -> dtype mappings for the gold standard and prediction files.
GOLDSTANDARD_COLS = {"id": str, "disease": int}
EXPECTED_COLS = {"id": str, "disease_probability": np.float64}
# GOLDSTANDARD_COLS is passed as `dtype` when validate() reads the gold
# standard file.
# NOTE(review): the gold `disease_probability` is read as str, presumably
# because it can contain non-numeric values - confirm.
GOLDSTANDARD_COLS = {"epr_number": str, "disease_probability": str}
EXPECTED_COLS = {"epr_number": str, "disease_probability": np.float64}


def get_args():
Expand All @@ -28,18 +28,18 @@ def get_args():

def check_dups(pred, id_colname="epr_number"):
    """Check for duplicate participant IDs.

    Args:
        pred: Prediction dataframe containing the ``id_colname`` column.
        id_colname: Name of the participant ID column.

    Returns:
        An error message giving the number of duplicated rows and their IDs,
        or an empty string when every ID is unique.
    """
    # `duplicated` flags the second and later occurrences of each ID, so the
    # first row of every ID is not counted.
    duplicates = pred.duplicated(subset=[id_colname])
    if duplicates.any():
        return (
            f"Found {duplicates.sum()} duplicate ID(s): "
            f"{pred.loc[duplicates, id_colname].to_list()}"
        )
    return ""


def check_missing_ids(gold, pred):
"""Check for missing participant IDs."""
pred = pred.set_index("id")
pred = pred.set_index("epr_number")
missing_ids = gold.index.difference(pred.index)
if missing_ids.any():
return (
Expand All @@ -51,7 +51,7 @@ def check_missing_ids(gold, pred):

def check_unknown_ids(gold, pred):
"""Check for unknown participant IDs."""
pred = pred.set_index("id")
pred = pred.set_index("epr_number")
unknown_ids = pred.index.difference(gold.index)
if unknown_ids.any():
return (
Expand Down Expand Up @@ -92,7 +92,7 @@ def validate(gold_folder, pred_file):
"""Validate predictions file against goldstandard."""
errors = []
gold_file = extract_gs_file(gold_folder)
gold = pd.read_csv(gold_file, dtype=GOLDSTANDARD_COLS, index_col="id")
gold = pd.read_csv(gold_file, dtype=GOLDSTANDARD_COLS, index_col="epr_number")
try:
pred = pd.read_csv(
pred_file,
Expand Down

0 comments on commit f610c7c

Please sign in to comment.