Skip to content

Commit

Permalink
fix wrong mention_score being assigned to antecedent mentions
Browse files Browse the repository at this point in the history
  • Loading branch information
Aethor committed Jun 12, 2024
1 parent 374b123 commit d6729c1
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions tibert/bertcoref.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,8 @@ class BertCoreferenceResolutionOutput:
# (batch_size, top_mentions_nb)
top_mentions_index: torch.Tensor

# (batch_size, top_mentions_nb)
top_mentions_scores: torch.Tensor
# (batch_size, spans_nb)
mentions_scores: torch.Tensor

# (batch_size, top_mentions_nb, antecedents_nb)
top_antecedents_index: torch.Tensor
Expand Down Expand Up @@ -873,38 +873,40 @@ def coreference_documents(

G = nx.Graph()
for m_j in range(top_mentions_nb):
span_idx = int(self.top_mentions_index[b_i][m_j].item())
span_coords = spans_idx[span_idx]

top_antecedent_idx = int(antecedents_idx[b_i][m_j].item())
span_i = int(self.top_mentions_index[b_i][m_j].item())
span_coords = spans_idx[span_i]

mention_score = float(self.top_mentions_scores[b_i][m_j].item())
mention_score = float(self.mentions_scores[b_i][span_i].item())
span_mention = Mention(
tokens[b_i][span_coords[0] : span_coords[1]],
span_coords[0],
span_coords[1],
mention_score=mention_score,
)

# index of the best antecedent in self.top_antecedent_index
top_antecedent_idx = int(antecedents_idx[b_i][m_j].item())

# the antecedent is the dummy mention : maybe we have
# a one-mention chain ?
if top_antecedent_idx == antecedents_nb - 1:
if float(self.top_mentions_scores[b_i][m_j].item()) > 0.0:
if float(self.mentions_scores[b_i][span_i].item()) > 0.0:
G.add_node(span_mention)
continue

antecedent_idx = int(
antecedent_span_i = int(
self.top_antecedents_index[b_i][m_j][top_antecedent_idx].item()
)
antecedent_coords = spans_idx[antecedent_span_i]

antecedent_coords = spans_idx[antecedent_idx]

mention_score = float(self.top_mentions_scores[b_i][m_j].item())
antecedent_mention_score = float(
self.mentions_scores[b_i][antecedent_span_i].item()
)
antecedent_mention = Mention(
tokens[b_i][antecedent_coords[0] : antecedent_coords[1]],
antecedent_coords[0],
antecedent_coords[1],
mention_score=mention_score,
mention_score=antecedent_mention_score,
)

G.add_node(antecedent_mention)
Expand Down Expand Up @@ -1510,7 +1512,7 @@ def forward(
return BertCoreferenceResolutionOutput(
final_scores,
top_mentions_index,
top_mention_scores,
mention_scores,
top_antecedents_index,
self.config.max_span_size,
loss=loss,
Expand Down

0 comments on commit d6729c1

Please sign in to comment.