Skip to content

Commit 8bc9c28

Browse files
authored
Merge pull request #3591 from flairNLP/relation_printout
Modify printouts in RelationClassifier evaluation to remove clutter
2 parents 68508cc + 58e903c commit 8bc9c28

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

flair/models/relation_classifier_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,32 @@ def predict(
667667

668668
return loss if return_loss else None
669669

670+
def _print_predictions(self, batch, gold_label_type: str) -> list[str]:
671+
lines = []
672+
for datapoint in batch:
673+
# check if there is a label mismatch
674+
g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)]
675+
p = [label.labeled_identifier for label in datapoint.get_labels("predicted")]
676+
g.sort()
677+
p.sort()
678+
679+
# if the gold label is O and is correctly predicted as no label, do not print out as this clutters
680+
# the output file with trivial predictions
681+
if not (
682+
len(datapoint.get_labels(gold_label_type)) == 1
683+
and datapoint.get_label(gold_label_type).value == "O"
684+
and len(datapoint.get_labels("predicted")) == 0
685+
):
686+
correct_string = " -> MISMATCH!\n" if g != p else ""
687+
eval_line = (
688+
f"{datapoint.text}\n"
689+
f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n"
690+
f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n"
691+
f"{correct_string}\n"
692+
)
693+
lines.append(eval_line)
694+
return lines
695+
670696
def _get_state_dict(self) -> dict[str, Any]:
671697
model_state: dict[str, Any] = {
672698
**super()._get_state_dict(),

0 commit comments

Comments
 (0)