Skip to content

Commit

Permalink
Merge branch 'main' of github.com:QUMIA/train-scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
kretep committed Nov 21, 2024
2 parents 8a03fe2 + 74c34fe commit 410e62e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions qumia_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,21 @@ def validate(trainer: QUMIA_Trainer, n_batches=None, set_type='validation', fold
print("Possible mismatch between labels and inputs!")
#raise Exception("Mismatch between labels and inputs")

# Save the dataframe to a csv file
# Prepare the output directory
val_output_dir = os.path.join(trainer.output_dir, folder)
os.makedirs(val_output_dir, exist_ok=True)

# Save the dataframe to a csv file
df_combined.to_csv(os.path.join(val_output_dir, f'df_{set_type}_predictions.csv'), index=False)

# Create a confusion matrix
create_confusion_matrix(rounded_predictions.tolist(), labels.tolist(), set_type, val_output_dir)

# WandB confusion matrix
label_list = [value - 1 for value in labels.astype(int)]
pred_list = [value - 1 for value in rounded_predictions.astype(int)]
wandb.log({"cm_" + folder: wandb.plot.confusion_matrix(probs=None,
y_true=labels, preds=predictions,
y_true=label_list, preds=pred_list,
class_names=['1.0', '2.0', '3.0', '4.0'])})

return df_combined
Expand Down

0 comments on commit 410e62e

Please sign in to comment.