File tree Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Expand file tree Collapse file tree 1 file changed +26
-0
lines changed Original file line number Diff line number Diff line change @@ -667,6 +667,32 @@ def predict(
667
667
668
668
return loss if return_loss else None
669
669
670
+ def _print_predictions (self , batch , gold_label_type : str ) -> list [str ]:
671
+ lines = []
672
+ for datapoint in batch :
673
+ # check if there is a label mismatch
674
+ g = [label .labeled_identifier for label in datapoint .get_labels (gold_label_type )]
675
+ p = [label .labeled_identifier for label in datapoint .get_labels ("predicted" )]
676
+ g .sort ()
677
+ p .sort ()
678
+
679
+ # if the gold label is O and is correctly predicted as no label, do not print out as this clutters
680
+ # the output file with trivial predictions
681
+ if not (
682
+ len (datapoint .get_labels (gold_label_type )) == 1
683
+ and datapoint .get_label (gold_label_type ).value == "O"
684
+ and len (datapoint .get_labels ("predicted" )) == 0
685
+ ):
686
+ correct_string = " -> MISMATCH!\n " if g != p else ""
687
+ eval_line = (
688
+ f"{ datapoint .text } \n "
689
+ f" - Gold: { ', ' .join (label .value if label .data_point == datapoint else label .labeled_identifier for label in datapoint .get_labels (gold_label_type ))} \n "
690
+ f" - Pred: { ', ' .join (label .value if label .data_point == datapoint else label .labeled_identifier for label in datapoint .get_labels ('predicted' ))} \n "
691
+ f"{ correct_string } \n "
692
+ )
693
+ lines .append (eval_line )
694
+ return lines
695
+
670
696
def _get_state_dict (self ) -> dict [str , Any ]:
671
697
model_state : dict [str , Any ] = {
672
698
** super ()._get_state_dict (),
You can’t perform that action at this time.
0 commit comments