Skip to content

Commit

Permalink
Merge pull request #3 from jaymedina/extract-gs-file
Browse files Browse the repository at this point in the history
Logic to extract the goldstandard file from folder
  • Loading branch information
vpchung authored May 7, 2024
2 parents af429ae + f5d5a3c commit 7acb610
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
17 changes: 15 additions & 2 deletions score.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@

import argparse
import json
import os

import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score

from glob import glob

def get_args():
"""Set up command-line interface and get arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--predictions_file", type=str, required=True)
parser.add_argument("-g", "--goldstandard_file", type=str, required=True)
parser.add_argument("-g", "--goldstandard_folder", type=str, required=True)
parser.add_argument("-o", "--output", type=str, default="results.json")
return parser.parse_args()

Expand All @@ -31,16 +33,27 @@ def score(gold, gold_col, pred, pred_col):
return {"auc_roc": roc, "auprc": pr}


def extract_gs_file(folder):
"""Extract gold standard file from folder."""
files = glob(os.path.join(folder, "*"))
if len(files) != 1:
raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.")

return files[0]


def main():
"""Main function."""
args = get_args()

with open(args.output, encoding="utf-8") as out:
res = json.load(out)

gold_file = extract_gs_file(args.goldstandard_folder)

if res.get("validation_status") == "VALIDATED":
pred = pd.read_csv(args.predictions_file)
gold = pd.read_csv(args.goldstandard_file)
gold = pd.read_csv(gold_file)
scores = score(gold, "disease", pred, "disease_probability")
status = "SCORED"
else:
Expand Down
21 changes: 17 additions & 4 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

import argparse
import json
import os

import pandas as pd
import numpy as np
import pandas as pd

from glob import glob

EXPECTED_COLS = {
'id': str,
Expand All @@ -23,7 +26,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--predictions_file",
type=str, required=True)
parser.add_argument("-g", "--goldstandard_file",
parser.add_argument("-g", "--goldstandard_folder",
type=str, required=True)
parser.add_argument("-o", "--output",
type=str, default="results.json")
Expand Down Expand Up @@ -82,10 +85,20 @@ def check_prob_values(pred):
return ""


def validate(gold_file, pred_file):
def extract_gs_file(folder):
"""Extract gold standard file from folder."""
files = glob(os.path.join(folder, "*"))
if len(files) != 1:
raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.")

return files[0]


def validate(gold_folder, pred_file):
"""Validate predictions file against goldstandard."""
errors = []

gold_file = extract_gs_file(gold_folder)
gold = pd.read_csv(gold_file, index_col="id")
try:
pred = pd.read_csv(
Expand Down Expand Up @@ -117,7 +130,7 @@ def main():
errors = [f.read()]
else:
errors = validate(
gold_file=args.goldstandard_file,
gold_folder=args.goldstandard_folder,
pred_file=args.predictions_file
)

Expand Down

0 comments on commit 7acb610

Please sign in to comment.