99from stanza .pipeline ._constants import *
1010from stanza .pipeline .processor import UDProcessor , register_processor
1111
12+ import torch
13+
1214def extract_text (document , sent_id , start_word , end_word ):
1315 sentence = document .sentences [sent_id ]
1416 tokens = []
@@ -128,6 +130,11 @@ def process(self, document):
128130 best_span = None
129131 max_propn = 0
130132 for span_idx , span in enumerate (span_cluster ):
133+ word_idx = results .word_clusters [cluster_idx ][span_idx ]
134+ is_zero = zero_nodes_created .get ((cluster_idx , word_idx ))
135+ if is_zero :
136+ continue
137+
131138 sent_id = sent_ids [span [0 ]]
132139 sentence = sentences [sent_id ]
133140 start_word = word_pos [span [0 ]]
@@ -145,21 +152,33 @@ def process(self, document):
145152 max_propn = num_propn
146153
147154 mentions = []
148- for span in span_cluster :
149- sent_id = sent_ids [span [0 ]]
150- start_word = word_pos [span [0 ]]
151- end_word = word_pos [span [1 ]- 1 ] + 1
152- mentions .append (CorefMention (sent_id , start_word , end_word ))
153-
154- # Add zero node mentions to this cluster if any exist
155- for zero_cluster_idx , zero_sent_id , zero_word_decimal_id in zero_nodes_created :
156- if zero_cluster_idx == cluster_idx :
157- # Zero node is a single "word" mention at the decimal position
158- import math
159- end_word = math .floor (zero_word_decimal_id ) + 1
160- mentions .append (CorefMention (zero_sent_id , zero_word_decimal_id , end_word ))
161- representative = mentions [best_span ]
162- representative_text = extract_text (document , representative .sentence , representative .start_word , representative .end_word )
155+ for span_idx , span in enumerate (span_cluster ):
156+ word_idx = results .word_clusters [cluster_idx ][span_idx ]
157+ is_zero = zero_nodes_created .get ((cluster_idx , word_idx ))
158+ if is_zero :
159+ (sent_id , zero_word_id ) = is_zero
160+ # zero_word_id is the tuple id of the empty (zero) node; passing it as
161+ # both start and end attaches this mention to that zero node
162+ mentions .append (
163+ CorefMention (
164+ sent_id ,
165+ zero_word_id ,
166+ zero_word_id
167+ )
168+ )
169+ else :
170+ sent_id = sent_ids [span [0 ]]
171+ start_word = word_pos [span [0 ]]
172+ end_word = word_pos [span [1 ]- 1 ] + 1
173+ mentions .append (CorefMention (sent_id , start_word , end_word ))
174+
175+ # if we ended up with no best span, then our "representative text"
176+ # is just underscore
177+ if best_span :
178+ representative = mentions [best_span ]
179+ representative_text = extract_text (document , representative .sentence , representative .start_word , representative .end_word )
180+ else :
181+ representative_text = "_"
163182
164183 chain = CorefChain (len (clusters ), mentions , representative_text , best_span )
165184 clusters .append (chain )
@@ -173,15 +192,26 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
173192 return
174193
175194 zero_scores = results .zero_scores .squeeze (- 1 ) if results .zero_scores .dim () > 1 else results .zero_scores
195+ is_zero = []
176196
177197 # Flatten word_clusters to get the word indices that correspond to zero_scores
178198 cluster_word_ids = []
179- for cluster in results .word_clusters :
199+ cluster_mapping = {}
200+ counter = 0
201+ for indx , cluster in enumerate (results .word_clusters ):
202+ for _ in range (len (cluster )):
203+ cluster_mapping [counter ] = indx
204+ counter += 1
180205 cluster_word_ids .extend (cluster )
181206
182207 # Find indices where zero_scores > 0
183- zero_indices = (zero_scores > 0 ).nonzero (as_tuple = True )[0 ]
184-
208+ print (zero_scores )
209+ zero_indices = (zero_scores > 0.0 ).nonzero ()
210+
211+ # this dict maps (cluster_idx, word_idx) to (sent_id, zero_word_id);
212+ # process() uses the entry in place of the matching span in span_clusters
213+ zero_to_coref = {}
214+
185215 for zero_idx in zero_indices :
186216 zero_idx = zero_idx .item ()
187217 if zero_idx >= len (cluster_word_ids ):
@@ -193,17 +223,21 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
193223
194224 # Create zero node - attach BEFORE the current word
195225 # This means the zero node comes after word_id-1 but before word_id
196- if word_id > 0 :
197- zero_word_id = (word_id , 1 ) # attach after word_id-1, before word_id
198- zero_word = Word (document .sentences [sent_id ], {
199- "text" : "_" ,
200- "lemma" : "_" ,
201- "id" : zero_word_id
202- })
203- document .sentences [sent_id ]._empty_words .append (zero_word )
204-
205- # Track this zero node for adding to coreference clusters
206- cluster_idx , _ = cluster_mapping [zero_idx ]
207- zero_nodes_created .append ((cluster_idx , sent_id , word_id + 0.1 ))
208-
209- return zero_nodes_created
226+ zero_word_id = (
227+ word_id ,
228+ len (document .sentences [sent_id ]._empty_words )+ 1
229+ ) # attach after word_id-1, before word_id
230+ zero_word = Word (document .sentences [sent_id ], {
231+ "text" : "_" ,
232+ "lemma" : "_" ,
233+ "id" : zero_word_id
234+ })
235+ document .sentences [sent_id ]._empty_words .append (zero_word )
236+
237+ # Track this zero node for adding to coreference clusters
238+ cluster_idx = cluster_mapping [zero_idx ]
239+ zero_to_coref [(cluster_idx , word_idx )] = (
240+ sent_id , zero_word_id
241+ )
242+
243+ return zero_to_coref
0 commit comments