From b1d77bda3488537d3ab9de945d22548665ac6576 Mon Sep 17 00:00:00 2001 From: Aethor Date: Wed, 28 Aug 2024 17:47:10 +0200 Subject: [PATCH 01/13] inference between pairs of spans is now done at the token level instead than at the wordpiece level --- tibert/bertcoref.py | 290 ++++++++++++++++++++++++-------------------- tibert/predict.py | 58 ++------- 2 files changed, 167 insertions(+), 181 deletions(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index 77bf456..bf43c2d 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -16,12 +16,10 @@ from dataclasses import dataclass from sacremoses import MosesTokenizer from more_itertools.recipes import flatten -import numpy as np import torch from torch.nn.parameter import Parameter from torch.utils.data import Dataset -from torch.utils.data.dataloader import DataLoader -from transformers import PreTrainedModel, BertPreTrainedModel # type: ignore +from transformers import BertPreTrainedModel # type: ignore from transformers import PreTrainedTokenizerFast # type: ignore from transformers.file_utils import PaddingStrategy from transformers.data.data_collator import DataCollatorMixin @@ -157,7 +155,10 @@ def document_labels(self, max_span_size: int) -> Tuple[torch.Tensor, torch.Tenso return (self.coref_labels(max_span_size), self.mention_labels(max_span_size)) def prepared_document( - self, tokenizer: PreTrainedTokenizerFast, max_span_size: int + self, + tokenizer: PreTrainedTokenizerFast, + max_span_size: int, + mode: Literal["train", "test"] = "train", ) -> Tuple[CoreferenceDocument, BatchEncoding]: """Prepare a document for being inputted into a model. The document is retokenized thanks to ``tokenizer``, and @@ -169,6 +170,10 @@ def prepared_document( method. This means special tokens will be added. 
:param tokenizer: tokenizer used to retokenized the document + :param max_span_size: + :param mode: if 'test', dont include labels in the returned + :class:`BatchEncoding` + :return: a tuple : - a new :class:`CoreferenceDocument` @@ -222,64 +227,14 @@ def prepared_document( new_chains.append(new_chain) document = CoreferenceDocument(tokens, new_chains) - coref_labels, mention_labels = document.document_labels(max_span_size) - batch["coref_labels"] = coref_labels - batch["mention_labels"] = mention_labels + if mode == "train": + coref_labels, mention_labels = self.document_labels(max_span_size) + batch["coref_labels"] = coref_labels + batch["mention_labels"] = mention_labels + batch["word_ids"] = torch.tensor([-1 if i is None else i for i in words_ids]) return document, batch - def from_wpieced_to_tokenized( - self, tokens: List[str], wp_to_token: List[int] - ) -> CoreferenceDocument: - """Convert the current document, tokenized with wordpieces, to - a document 'normally' tokenized - - :param tokens: the original tokens - :param wp_to_token: mapping from wordpiece index to token index - - :return: - """ - # In some cases, output mentions can be produced several - # times. This can happen when a mention boundary is expressed - # by several wordpiece (for example the wordpiece mentions : - # "l ' abbé" and "l '" can both correspond to the regular - # mention "l 'abbé". For those cases, we keep only one choice - # and assume that predictions for these mentions are - # consistent. - already_visited_mentions = set() - - new_chains = [] - - for chain in self.coref_chains: - new_chain = [] - - for mention in chain: - new_start_idx = wp_to_token[mention.start_idx] - new_end_idx = wp_to_token[mention.end_idx - 1] - # NOTE: this happens in case the model has predicted - # an erroneous mention such as '[CLS]' or '[SEP]'. In - # that case, we simply ignore the mention. 
- if new_start_idx is None or new_end_idx is None: - continue - new_end_idx += 1 - - new_mention = Mention( - tokens[new_start_idx:new_end_idx], - new_start_idx, - new_end_idx, - ) - - if new_mention in already_visited_mentions: - continue - - new_chain.append(new_mention) - - already_visited_mentions.add(new_mention) - - new_chains.append(new_chain) - - return CoreferenceDocument(tokens, new_chains) - @staticmethod def from_labels( tokens: List[str], @@ -365,20 +320,6 @@ class DataCollatorForSpanClassification(DataCollatorMixin): return_tensors: Literal["pt"] = "pt" def torch_call(self, features) -> Union[dict, BatchEncoding]: - coref_labels = ( - [feature["coref_labels"] for feature in features] - if "coref_labels" in features[0].keys() - else None - ) - mention_labels = ( - [feature["mention_labels"] for feature in features] - if "mention_labels" in features[0].keys() - else None - ) - assert (coref_labels is None and mention_labels is None) or ( - coref_labels and mention_labels - ) - warning_state = self.tokenizer.deprecation_warnings.get( "Asking-to-pad-a-fast-tokenizer", False ) @@ -387,51 +328,59 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]: features, padding=self.padding, max_length=self.max_length, - # Conversion to tensors will fail if we have labels as - # they are not of the same length yet. - return_tensors="pt" if coref_labels is None else None, + # Conversion to tensors will fail as they are not of the + # same length yet. 
+ return_tensors=None, + ) + self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = ( + warning_state ) - self.tokenizer.deprecation_warnings[ - "Asking-to-pad-a-fast-tokenizer" - ] = warning_state # keep encoding info batch._encodings = [f.encodings[0] for f in features] - if coref_labels is None: - return batch - - documents = [ - CoreferenceDocument.from_labels( - tokens, coref_labels, mention_labels, max_span_size=self.max_span_size - ) - for tokens, coref_labels, mention_labels in zip( - [f["input_ids"] for f in features], - [f["coref_labels"] for f in features], - [f["mention_labels"] for f in features], - ) - ] + device = torch.device(self.device) - for document, tokens in zip(documents, batch["input_ids"]): # type: ignore - document.tokens = tokens - labels = [doc.document_labels(self.max_span_size) for doc in documents] + if "coref_labels" in batch: + # TODO: if we can find a better way to pad sparse tensors, + # that would be great... At least the big tensor never finds + # its way onto the GPU. 
+ max_p = max([tens.shape[0] for tens in batch["coref_labels"]]) + for tens_i, tens in enumerate(batch["coref_labels"]): + p = tens.shape[0] + if max_p > p: + tens = torch.cat([tens.to_dense(), torch.zeros(max_p - p, p + 1)]) + batch["coref_labels"][tens_i] = torch.cat( + [ + tens[:, :-1], + torch.zeros(max_p, max_p - p), + tens[:, -1].unsqueeze(1), + ], + dim=1, + ).to_sparse_coo() + batch["coref_labels"] = torch.stack( + [tens for tens in batch["coref_labels"]] + ).to(device) + + batch["mention_labels"] = torch.nn.utils.rnn.pad_sequence( + batch["mention_labels"], batch_first=True + ).to(device) + + batch["word_ids"] = torch.nn.utils.rnn.pad_sequence( + batch["word_ids"], batch_first=True, padding_value=-1 + ).to(device) - device = torch.device(self.device) - del batch["coref_labels"] - del batch["mention_labels"] batch = BatchEncoding( { - k: torch.tensor(v, dtype=torch.int64, device=device) + k: ( + torch.tensor(v, dtype=torch.int64, device=device) + if not k in ("coref_labels", "mention_labels", "word_ids") + else v + ) for k, v in batch.items() }, encoding=batch.encodings, ) - batch["coref_labels"] = torch.stack( - [coref_labels for coref_labels, _ in labels] - ).to(device) - batch["mention_labels"] = torch.stack( - [mention_labels for _, mention_labels in labels] - ).to(device) return batch @@ -441,6 +390,7 @@ class CoreferenceDataset(Dataset): :ivar documents: :ivar tokenizer: :ivar max_span_len: + :ivar mode: """ def __init__( @@ -453,6 +403,13 @@ def __init__( self.documents = documents self.tokenizer = tokenizer self.max_span_size = max_span_size + self.mode: Literal["train", "test"] = "train" + + def set_test_(self): + self.mode = "test" + + def set_train_(self): + self.mode = "train" @staticmethod def from_conll2012_file( @@ -676,7 +633,9 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> BatchEncoding: document = self.documents[index] - _, batch = document.prepared_document(self.tokenizer, self.max_span_size) + _, batch = 
document.prepared_document( + self.tokenizer, self.max_span_size, self.mode + ) return batch @@ -850,23 +809,39 @@ def load_ontonotes_dataset( @dataclass class BertCoreferenceResolutionOutput: - # (batch_size, top_mentions_nb, antecedents_nb) + """Output of BertForCoreferenceResolution + + .. note :: + + We use the following short notation to annotate shapes : + + - b: batch_size + - s: seq_size (in wordpieces) + - w: seq_size_words (in words) + - p: spans_nb + - m: top_mentions_nb + - a: antecedents_nb + - h: hidden_size + - t: metadatas_features_size + """ + + # (b, m, a) logits: torch.Tensor - # (batch_size, top_mentions_nb) + # (b, m) top_mentions_index: torch.Tensor - # (batch_size, spans_nb) + # (b, p) mentions_scores: torch.Tensor - # (batch_size, top_mentions_nb, antecedents_nb) + # (b, m, a) top_antecedents_index: torch.Tensor max_span_size: int loss: Optional[torch.Tensor] = None - # (batch_size, seq_size, hidden_size) + # (b, w, h) hidden_states: Optional[torch.FloatTensor] = None def coreference_documents( @@ -875,7 +850,7 @@ def coreference_documents( """Extract a :class:`.CoreferenceDocument` list from a coreference output. 
- :param tokens: + :param tokens: the original tokens :return: a list of :class:`.CoreferenceDocument`, one per batch @@ -977,7 +952,8 @@ class BertForCoreferenceResolution(BertPreTrainedModel): We use the following short notation to annotate shapes : - b: batch_size - - s: seq_size + - s: seq_size (in wordpieces) + - w: seq_size_words (in words) - p: spans_nb - m: top_mentions_nb - a: antecedents_nb @@ -1100,7 +1076,7 @@ def mention_compatibility_score( return self.mention_compatibility_scorer(score).squeeze(-1) def pruned_mentions_indexs( - self, mention_scores: torch.Tensor, seq_size: int, top_mentions_nb: int + self, mention_scores: torch.Tensor, words_nb: int, top_mentions_nb: int ) -> torch.Tensor: """Prune mentions, keeping only the k non-overlapping best of them @@ -1116,7 +1092,7 @@ def pruned_mentions_indexs( :param mention_scores: a tensor of shape ``(b, p)`` - :param seq_size: + :param words_nb: :param top_mentions_nb: the maximum number of spans to keep during the pruning process @@ -1128,7 +1104,7 @@ def pruned_mentions_indexs( assert top_mentions_nb <= spans_nb - spans_idx = spans_indexs(list(range(seq_size)), self.config.max_span_size) + spans_idx = spans_indexs(list(range(words_nb)), self.config.max_span_size) def spans_are_overlapping( span1: Tuple[int, int], span2: Tuple[int, int] @@ -1168,11 +1144,11 @@ def spans_are_overlapping( return mention_indexs - def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor: + def distance_between_spans(self, spans_nb: int, words_nb: int) -> torch.Tensor: """Compute the indexs of the k closest mentions :param spans_nb: number of spans in the sequence - :param seq_size: size of the sequence + :param words_nb: number of words in the sequence :return: a tensor of shape ``(p, p)`` """ p = spans_nb @@ -1180,7 +1156,7 @@ def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor: # a list of spans indices # [(start, end), ..., (start, end)] - spans_idx = 
spans_indexs(list(range(seq_size)), self.config.max_span_size) + spans_idx = spans_indexs(list(range(words_nb)), self.config.max_span_size) # (spans_nb,) start_idx = torch.tensor([start for start, _ in spans_idx]).to(device) @@ -1200,16 +1176,16 @@ def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor: return dist def closest_antecedents_indexs( - self, spans_nb: int, seq_size: int, antecedents_nb: int + self, spans_nb: int, words_nb: int, antecedents_nb: int ): """Compute the indexs of the k closest mentions :param spans_nb: number of spans in the sequence - :param seq_size: size of the sequence + :param words_nb: number of words in the sequence :param antecedents_nb: number of antecedents to consider :return: a tensor of shape ``(p, a)`` """ - dist = self.distance_between_spans(spans_nb, seq_size) + dist = self.distance_between_spans(spans_nb, words_nb) assert dist.shape == (spans_nb, spans_nb) # when the distance between a span and a possible antecedent @@ -1229,24 +1205,24 @@ def distance_feature( top_antecedents_index: torch.Tensor, top_mentions_index: torch.Tensor, spans_nb: int, - seq_size: int, + words_nb: int, ) -> torch.Tensor: """Compute the distance feature between two spans :param top_antecedents_index: ``(b, m, a)`` :param top_mentions_index: ``(b, m)`` :param spans_nb: - :param seq_size: + :param words_nb: number of words in the sequence :return: ``(b, m, a, t)`` """ b, m, a = top_antecedents_index.shape t = self.config.metadatas_features_size p = spans_nb - s = seq_size + w = words_nb device = next(self.parameters()).device - dist = self.distance_between_spans(p, s) + dist = self.distance_between_spans(p, w) dist = dist.unsqueeze(0).repeat(b, 1, 1) assert dist.shape == (b, p, p) @@ -1287,7 +1263,6 @@ def bert_encode( :return: hidden states of the last layer, of shape ``(b, s, h)`` """ - # list[(batch_size, <= segment_size, hidden_size)] last_hidden_states = [] @@ -1312,6 +1287,48 @@ def maybe_take_segment( return 
torch.cat(last_hidden_states, dim=1) + def wordreduce_embeddings( + self, encoded: torch.Tensor, word_ids: torch.LongTensor + ) -> torch.Tensor: + """ + :param encoded: an embedding tensor of shape ``(b, s, h)`` + :param word_ids: a tensor of shape ``(b, s)`` + :return: an embedding tensor of shape ``(b, w, h)`` + """ + device = next(self.parameters()).device + b, _, h = encoded.shape + + word_encoded = [] + for b_i in range(b): + # here is what we want to achieve: + # + # word_ids ~ [0 0 0 -1 1 1 2 2 ] + # encoded ~ [A B C D E F G H ] + # => [mean(ABC) mean(EF) mean(GH)] + # + # To do so we first filter encoded to remove -1s since + # they dont correspond to any tokens. Then, we use + # torch.index_reduce to achieve mean reduction. At the + # time of this writing, torch.index_reduce is in beta so + # it will print out a warning! Unfortunately there is no + # batched version of torch.index_reduce as far as I + # understand, so we do it in a loop :< + batch_word_ids = word_ids[b_i] + token_mask = batch_word_ids != -1 + batch_encoded = encoded[b_i][token_mask] + batch_word_ids = batch_word_ids[token_mask] + words_nb = len(set(batch_word_ids.tolist())) + words = torch.zeros(words_nb, h, device=device) + words.index_reduce_( + 0, batch_word_ids, batch_encoded, "mean", include_self=False + ) + word_encoded.append(words) + + # each 2D-tensor in word_encoded is of shape (w_i, h). w_i + # varies depending on the batch, so we have to do some + # padding! 
+ return torch.nn.utils.rnn.pad_sequence(word_encoded, batch_first=True) + def mention_loss( self, top_mention_scores: torch.Tensor, mention_labels: torch.Tensor ) -> torch.Tensor: @@ -1352,6 +1369,7 @@ def coref_loss( def forward( self, input_ids: torch.Tensor, + word_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -1362,6 +1380,7 @@ def forward( ) -> BertCoreferenceResolutionOutput: """ :param input_ids: a tensor of shape ``(b, s)`` + :param word_ids: a tensor of shape ``(b, s)`` :param attention_mask: a tensor of shape ``(b, s)`` :param token_type_ids: a tensor of shape ``(b, s)`` :param position_ids: a tensor of shape ``(b, s)`` @@ -1385,11 +1404,14 @@ def forward( head_mask=head_mask, ) assert encoded_input.shape == (b, s, h) + encoded_input = self.wordreduce_embeddings(encoded_input, word_ids) + words_nb = w = encoded_input.shape[1] + assert encoded_input.shape == (b, w, h) # -- span bounds computation -- # we select starting and ending bounds of spans of length up # to self.max_span_size - spans_idx = spans(range(seq_size), self.config.max_span_size) + spans_idx = spans(range(words_nb), self.config.max_span_size) spans_nb = p = len(spans_idx) spans_selector = torch.flatten( torch.tensor([[span[0], span[-1]] for span in spans_idx], dtype=torch.long) @@ -1409,9 +1431,11 @@ def forward( # top_mentions_index is the index of the m best # non-overlapping mentions - top_mentions_nb = m = int(self.config.mentions_per_tokens * seq_size) + top_mentions_nb = m = max( + min(1, words_nb), int(self.config.mentions_per_tokens * words_nb) + ) top_mentions_index = self.pruned_mentions_indexs( - mention_scores, seq_size, top_mentions_nb + mention_scores, words_nb, top_mentions_nb ) # TODO: hack top_mentions_nb = m = int(top_mentions_index.shape[1]) @@ -1421,7 +1445,7 @@ def forward( # antecedents for each spans antecedents_nb = a = 
min(self.config.antecedents_nb, spans_nb) antecedents_index = self.closest_antecedents_indexs( - spans_nb, seq_size, antecedents_nb + spans_nb, words_nb, antecedents_nb ) antecedents_index = torch.tile(antecedents_index, (batch_size, 1, 1)) assert antecedents_index.shape == (b, p, a) @@ -1449,7 +1473,7 @@ def forward( # distance feature computation dist_ft = self.distance_feature( - top_antecedents_index, top_mentions_index, spans_nb, seq_size + top_antecedents_index, top_mentions_index, spans_nb, words_nb ) dist_ft = torch.flatten(dist_ft, start_dim=1, end_dim=2) assert dist_ft.shape == (b, m * a, t) @@ -1502,7 +1526,7 @@ def forward( # NOTE: we have to rely on such a loop, as torch.gather # cannot be used on sparse tensors, which prevents using - # batch_index_select + # batch_index_select. selected_coref_labels = torch.stack( [ torch.index_select(coref_labels[b_i], 0, top_mentions_index[b_i]) diff --git a/tibert/predict.py b/tibert/predict.py index f59765a..98f98a5 100644 --- a/tibert/predict.py +++ b/tibert/predict.py @@ -215,14 +215,7 @@ def _stream_predict_wpieced_coref_raw( lang: str = "en", return_hidden_state: bool = False, ) -> Generator[ - Tuple[ - List[CoreferenceDocument], - List[CoreferenceDocument], - BatchEncoding, - BertCoreferenceResolutionOutput, - ], - None, - None, + Tuple[List[CoreferenceDocument], BertCoreferenceResolutionOutput], None, None ]: """Low level inference interface.""" @@ -248,6 +241,7 @@ def _stream_predict_wpieced_coref_raw( tokenizer, model.config.max_span_size, ) + dataset.set_test_() data_collator = DataCollatorForSpanClassification( tokenizer, model.config.max_span_size, device_str ) @@ -271,14 +265,9 @@ def _stream_predict_wpieced_coref_raw( **batch, return_hidden_state=return_hidden_state ) - out_docs = out.coreference_documents( - [ - [tokenizer.decode(t) for t in input_ids] # type: ignore - for input_ids in batch["input_ids"] - ] - ) + out_docs = out.coreference_documents([doc.tokens for doc in batch_docs]) - yield 
batch_docs, out_docs, batch, out + yield out_docs, out def stream_predict_coref( @@ -303,16 +292,11 @@ def stream_predict_coref( :return: a list of ``CoreferenceDocument``, with annotated coreference chains. """ - for original_docs, out_docs, batch, out in _stream_predict_wpieced_coref_raw( + for out_docs, _ in _stream_predict_wpieced_coref_raw( documents, model, tokenizer, batch_size, quiet, device_str, lang ): - for batch_i, (original_doc, out_doc) in enumerate(zip(original_docs, out_docs)): - seq_size = batch["input_ids"].shape[1] - wp_to_token = [ - batch.token_to_word(batch_i, token_index=i) for i in range(seq_size) - ] - doc = out_doc.from_wpieced_to_tokenized(original_doc.tokens, wp_to_token) - yield doc + for out_doc in out_docs: + yield out_doc def predict_coref( @@ -344,13 +328,11 @@ def predict_coref( if hierarchical_merging: docs = [] hidden_states = [] - all_tokens = [] - wp_to_token = [] if len(documents) == 0: return None - for original_docs, out_docs, batch, out in _stream_predict_wpieced_coref_raw( + for out_docs, out in _stream_predict_wpieced_coref_raw( documents, model, tokenizer, @@ -365,28 +347,8 @@ def predict_coref( assert not out.hidden_states is None hidden_states += [h for h in out.hidden_states] - all_tokens += list(flatten(doc.tokens for doc in original_docs)) - - batch_size = batch["input_ids"].shape[0] - seq_size = batch["input_ids"].shape[1] - for batch_i in range(batch_size): - # we need to shift the index of tokens in the batch by - # the index of the last token of the previous batch - max_prev_wp_to_token = ( - 0 - if len(wp_to_token) == 0 - else max([wtt for wtt in wp_to_token if not wtt is None], default=0) - ) - for i in range(seq_size): - wtt = batch.token_to_word(batch_i, token_index=i) - if not wtt is None and max_prev_wp_to_token != 0: - wtt += max_prev_wp_to_token + 1 - wp_to_token.append(wtt) - - merged_doc_wpieced = merge_coref_outputs(docs, hidden_states, model, device_str) - assert not merged_doc_wpieced is None # we 
know that len(docs) > 0 - - return merged_doc_wpieced.from_wpieced_to_tokenized(all_tokens, wp_to_token) + merged_doc = merge_coref_outputs(docs, hidden_states, model, device_str) + return merged_doc return list( stream_predict_coref( From 035e85a98c0a5ed5154e4ada6f2a489e868fb454 Mon Sep 17 00:00:00 2001 From: Aethor Date: Wed, 28 Aug 2024 18:09:46 +0200 Subject: [PATCH 02/13] fix a possible crash in the model output parsing --- tibert/bertcoref.py | 9 ++++++++- tibert/predict.py | 6 +++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index bf43c2d..bbcbd15 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -872,7 +872,14 @@ def coreference_documents( G = nx.Graph() for m_j in range(top_mentions_nb): span_i = int(self.top_mentions_index[b_i][m_j].item()) - span_coords = spans_idx[span_i] + # it is possible to have a top span that does not + # actually exist in a batch sample. This is because + # padding is done on wordpieces but not on words. In + # that case, we simply ignore that predicted span. + try: + span_coords = spans_idx[span_i] + except IndexError: + continue mention_score = float(self.mentions_scores[b_i][span_i].item()) span_mention = Mention( diff --git a/tibert/predict.py b/tibert/predict.py index 98f98a5..34a3ec8 100644 --- a/tibert/predict.py +++ b/tibert/predict.py @@ -205,7 +205,7 @@ def tensorize_chains( return CoreferenceDocument(merged_left.tokens + merged_right.tokens, new_chains) -def _stream_predict_wpieced_coref_raw( +def _stream_predict_coref_raw( documents: List[Union[str, List[str]]], model: BertForCoreferenceResolution, tokenizer: PreTrainedTokenizerFast, @@ -292,7 +292,7 @@ def stream_predict_coref( :return: a list of ``CoreferenceDocument``, with annotated coreference chains. 
""" - for out_docs, _ in _stream_predict_wpieced_coref_raw( + for out_docs, _ in _stream_predict_coref_raw( documents, model, tokenizer, batch_size, quiet, device_str, lang ): for out_doc in out_docs: @@ -332,7 +332,7 @@ def predict_coref( if len(documents) == 0: return None - for out_docs, out in _stream_predict_wpieced_coref_raw( + for out_docs, out in _stream_predict_coref_raw( documents, model, tokenizer, From 82fd0d06ffa7a4a108dd62d4b4528ef57dc4e57e Mon Sep 17 00:00:00 2001 From: Aethor Date: Wed, 28 Aug 2024 18:23:17 +0200 Subject: [PATCH 03/13] fix another possible crash in the model output parsing --- tibert/bertcoref.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index bbcbd15..a92ac19 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -902,7 +902,10 @@ def coreference_documents( antecedent_span_i = int( self.top_antecedents_index[b_i][m_j][top_antecedent_idx].item() ) - antecedent_coords = spans_idx[antecedent_span_i] + try: + antecedent_coords = spans_idx[antecedent_span_i] + except IndexError: + continue antecedent_mention_score = float( self.mentions_scores[b_i][antecedent_span_i].item() From a1c97fe73907103567f367d3553c9709f07c6ea3 Mon Sep 17 00:00:00 2001 From: Aethor Date: Thu, 29 Aug 2024 17:23:25 +0200 Subject: [PATCH 04/13] add mention detection metrics reporting for train/test scripts --- tibert/run_test.py | 15 +++++++++++++-- tibert/run_train.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/tibert/run_test.py b/tibert/run_test.py index d888d7b..904ef57 100644 --- a/tibert/run_test.py +++ b/tibert/run_test.py @@ -15,7 +15,7 @@ BertForCoreferenceResolution, CamembertForCoreferenceResolution, ) -from tibert.score import score_coref_predictions +from tibert.score import score_coref_predictions, score_mention_detection from tibert.predict import predict_coref from tibert.utils import split_coreference_document_tokens @@ -82,7 +82,7 @@ 
def main( all_annotated_docs = [] for document in tqdm(test_dataset.documents): doc_dataset = CoreferenceDataset( - split_coreference_document_tokens(document, 512), + [document], tokenizer, max_span_size, ) @@ -110,6 +110,17 @@ def main( annotated_doc = CoreferenceDocument.concatenated(annotated_docs) all_annotated_docs.append(annotated_doc) + mention_pre, mention_rec, mention_f1 = score_mention_detection( + all_annotated_docs, test_dataset.documents + ) + for metric_key, score in [ + ("precision", mention_pre), + ("recall", mention_rec), + ("f1", mention_f1), + ]: + print(f"mention.{metric_key}={score}") + _run.log_scalar(f"mention.{metric_key}", score) + scores = score_coref_predictions(all_annotated_docs, test_dataset.documents) for key, score_dict in scores.items(): for metric_key, score in score_dict.items(): diff --git a/tibert/run_train.py b/tibert/run_train.py index 10c1e0c..d6ea3a3 100644 --- a/tibert/run_train.py +++ b/tibert/run_train.py @@ -14,6 +14,7 @@ load_train_checkpoint, predict_coref, score_coref_predictions, + score_mention_detection, ) from tibert.bertcoref import CoreferenceDataset, load_democrat_dataset @@ -145,6 +146,17 @@ def main( ) assert isinstance(annotated_docs, list) + mention_pre, mention_rec, mention_f1 = score_mention_detection( + annotated_docs, test_dataset.documents + ) + for metric_key, score in [ + ("precision", mention_pre), + ("recall", mention_rec), + ("f1", mention_f1), + ]: + print(f"mention.{metric_key}={score}") + _run.log_scalar(f"mention.{metric_key}", score) + metrics = score_coref_predictions(annotated_docs, test_dataset.documents) print(metrics) for key, score_dict in metrics.items(): From 3efc7ef84691847b80749bb02fa1adb2a81347b3 Mon Sep 17 00:00:00 2001 From: Aethor Date: Thu, 29 Aug 2024 17:43:23 +0200 Subject: [PATCH 05/13] tibert.run_test now correctly support multiple document split configurations --- tibert/run_test.py | 66 +++++++++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 27 
deletions(-) diff --git a/tibert/run_test.py b/tibert/run_test.py index 904ef57..09f9d24 100644 --- a/tibert/run_test.py +++ b/tibert/run_test.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional import os import functools as ft from transformers import BertTokenizerFast, CamembertTokenizerFast # type: ignore @@ -6,6 +6,7 @@ from sacred.experiment import Experiment from sacred.run import Run from sacred.commands import print_config +from tibert import predict from tibert.bertcoref import ( CoreferenceDataset, CoreferenceDocument, @@ -29,6 +30,7 @@ def config(): dataset_name: str = "litbank" dataset_path: str = os.path.expanduser("~/litbank") max_span_size: int = 10 + limit_doc_size: Optional[int] = None hierarchical_merging: bool = False device_str: str = "auto" model_path: str @@ -41,6 +43,7 @@ def main( dataset_name: Literal["litbank", "fr-litbank", "democrat"], dataset_path: str, max_span_size: int, + limit_doc_size: Optional[int], hierarchical_merging: bool, device_str: Literal["cuda", "cpu", "auto"], model_path: str, @@ -79,36 +82,45 @@ def main( ) _, test_dataset = dataset.splitted(0.9) - all_annotated_docs = [] - for document in tqdm(test_dataset.documents): - doc_dataset = CoreferenceDataset( - [document], + if limit_doc_size is None: + all_annotated_docs = predict_coref( + [doc.tokens for doc in dataset.documents], + model, tokenizer, - max_span_size, + device_str=device_str, + batch_size=batch_size, ) - if hierarchical_merging: - annotated_doc = predict_coref( - [doc.tokens for doc in doc_dataset.documents], - model, + assert isinstance(all_annotated_docs, list) + else: + all_annotated_docs = [] + for document in tqdm(test_dataset.documents): + doc_dataset = CoreferenceDataset( + split_coreference_document_tokens(document, 512), tokenizer, - hierarchical_merging=True, - quiet=True, - device_str=device_str, - batch_size=batch_size, + max_span_size, ) - else: - annotated_docs = predict_coref( - [doc.tokens for doc in 
doc_dataset.documents], - model, - tokenizer, - hierarchical_merging=False, - quiet=True, - device_str=device_str, - batch_size=batch_size, - ) - assert isinstance(annotated_docs, list) - annotated_doc = CoreferenceDocument.concatenated(annotated_docs) - all_annotated_docs.append(annotated_doc) + if hierarchical_merging: + annotated_doc = predict_coref( + [doc.tokens for doc in doc_dataset.documents], + model, + tokenizer, + hierarchical_merging=True, + quiet=True, + device_str=device_str, + batch_size=batch_size, + ) + else: + annotated_docs = predict_coref( + [doc.tokens for doc in doc_dataset.documents], + model, + tokenizer, + quiet=True, + device_str=device_str, + batch_size=batch_size, + ) + assert isinstance(annotated_docs, list) + annotated_doc = CoreferenceDocument.concatenated(annotated_docs) + all_annotated_docs.append(annotated_doc) mention_pre, mention_rec, mention_f1 = score_mention_detection( all_annotated_docs, test_dataset.documents From 9be19414112a7969a9c8b92ed6f67d2af898988d Mon Sep 17 00:00:00 2001 From: Aethor Date: Thu, 29 Aug 2024 17:44:15 +0200 Subject: [PATCH 06/13] correctly use limit_doc_size argument for tibert.run_test --- tibert/run_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tibert/run_test.py b/tibert/run_test.py index 09f9d24..e0f465e 100644 --- a/tibert/run_test.py +++ b/tibert/run_test.py @@ -30,6 +30,7 @@ def config(): dataset_name: str = "litbank" dataset_path: str = os.path.expanduser("~/litbank") max_span_size: int = 10 + # in tokens limit_doc_size: Optional[int] = None hierarchical_merging: bool = False device_str: str = "auto" @@ -95,7 +96,7 @@ def main( all_annotated_docs = [] for document in tqdm(test_dataset.documents): doc_dataset = CoreferenceDataset( - split_coreference_document_tokens(document, 512), + split_coreference_document_tokens(document, limit_doc_size), tokenizer, max_span_size, ) From 3c158f9db774d1160cc735decc237d3f9c06af76 Mon Sep 17 00:00:00 2001 From: Aethor Date: Thu, 29 
Aug 2024 17:54:30 +0200 Subject: [PATCH 07/13] correctly target test_dataset in run_test.py --- tibert/run_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tibert/run_test.py b/tibert/run_test.py index e0f465e..0d92ef7 100644 --- a/tibert/run_test.py +++ b/tibert/run_test.py @@ -85,7 +85,7 @@ def main( if limit_doc_size is None: all_annotated_docs = predict_coref( - [doc.tokens for doc in dataset.documents], + [doc.tokens for doc in test_dataset.documents], model, tokenizer, device_str=device_str, From 7ee36ee2e6879e920acb812a869e70ea7d961606 Mon Sep 17 00:00:00 2001 From: Aethor Date: Sun, 15 Sep 2024 16:22:27 +0200 Subject: [PATCH 08/13] fix a bug in mention pruning --- tibert/bertcoref.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index a92ac19..b43322d 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -1093,9 +1093,9 @@ def pruned_mentions_indexs( The algorithm works as follows : 1. Sort mentions by individual scores - 2. Accept mention in orders, from best to worst score, until k of + 2. Accept mention in order, from best to worst score, until k of them are accepted. A mention can only be accepted if no - previously accepted span os overlapping with it. + previously accepted span is overlapping with it. See section 5 of the E2ECoref paper and the C++ kernel in the E2ECoref repository. @@ -1119,9 +1119,7 @@ def pruned_mentions_indexs( def spans_are_overlapping( span1: Tuple[int, int], span2: Tuple[int, int] ) -> bool: - return ( - span1[0] < span2[0] and span2[0] <= span1[1] and span1[1] < span2[1] - ) or (span2[0] < span1[0] and span1[0] <= span2[1] and span2[1] < span1[1]) + return not (span1[1] <= span2[0] or span1[0] >= span2[1]) _, sorted_indexs = torch.sort(mention_scores, 1, descending=True) # TODO: what if we can't have top_mentions_nb mentions ?? 
From bd6ebaf83e3bf4d026b446f602e822ec80cc7972 Mon Sep 17 00:00:00 2001 From: Aethor Date: Sun, 15 Sep 2024 18:05:05 +0200 Subject: [PATCH 09/13] ACTUALLY fix the span pruning issue --- tibert/bertcoref.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index b43322d..cb38958 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -1088,14 +1088,14 @@ def mention_compatibility_score( def pruned_mentions_indexs( self, mention_scores: torch.Tensor, words_nb: int, top_mentions_nb: int ) -> torch.Tensor: - """Prune mentions, keeping only the k non-overlapping best of them + """Prune mentions, keeping only the k non-crossing best of them The algorithm works as follows : 1. Sort mentions by individual scores 2. Accept mention in order, from best to worst score, until k of them are accepted. A mention can only be accepted if no - previously accepted span is overlapping with it. + previously accepted span is crossing with it. See section 5 of the E2ECoref paper and the C++ kernel in the E2ECoref repository. @@ -1116,10 +1116,12 @@ def pruned_mentions_indexs( spans_idx = spans_indexs(list(range(words_nb)), self.config.max_span_size) - def spans_are_overlapping( - span1: Tuple[int, int], span2: Tuple[int, int] - ) -> bool: - return not (span1[1] <= span2[0] or span1[0] >= span2[1]) + def spans_are_crossing(span1: Tuple[int, int], span2: Tuple[int, int]) -> bool: + start1, end1 = (span1[0], span1[1] - 1) + start2, end2 = (span2[0], span2[1] - 1) + return (start1 < start2 and start2 <= end1 and end1 < end2) or ( + start2 < start1 and start1 <= end2 and end2 < end1 + ) _, sorted_indexs = torch.sort(mention_scores, 1, descending=True) # TODO: what if we can't have top_mentions_nb mentions ?? 
@@ -1134,7 +1136,7 @@ def spans_are_overlapping( span_index = int(sorted_indexs[b_i][s_j].item()) if not any( [ - spans_are_overlapping( + spans_are_crossing( spans_idx[span_index], spans_idx[mention_idx] ) for mention_idx in mention_indexs[-1] From bccb36ededb3b63711504293312efaf0d70de615 Mon Sep 17 00:00:00 2001 From: Aethor Date: Sun, 15 Sep 2024 22:43:13 +0200 Subject: [PATCH 10/13] closest_antecedents_indexs now only select in non-pruned mentions --- tibert/bertcoref.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index cb38958..2d2fc62 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -1186,27 +1186,47 @@ def distance_between_spans(self, spans_nb: int, words_nb: int) -> torch.Tensor: return dist def closest_antecedents_indexs( - self, spans_nb: int, words_nb: int, antecedents_nb: int + self, + top_mentions_index: torch.Tensor, + spans_nb: int, + words_nb: int, + antecedents_nb: int, ): """Compute the indexs of the k closest mentions + :param top_mentions_index: a tensor of shape ``(b, m)`` :param spans_nb: number of spans in the sequence :param words_nb: number of words in the sequence :param antecedents_nb: number of antecedents to consider - :return: a tensor of shape ``(p, a)`` + :return: a tensor of shape ``(b, p, a)`` """ + device = next(self.parameters()).device + b, _ = top_mentions_index.shape + p = spans_nb + a = antecedents_nb + dist = self.distance_between_spans(spans_nb, words_nb) - assert dist.shape == (spans_nb, spans_nb) + assert dist.shape == (p, p) # when the distance between a span and a possible antecedent - # is 0 or negative, it means the possible antecedents is after - # the span. Therefore, it can't be an antecedents. We set - # those distances to Inf for torch.topk usage just after + # is 0 or negative, it means the possible antecedent is after + # the span. Therefore, it can't be an antecedent. 
We set those + # distances to Inf for torch.topk usage just after dist[dist <= 0] = float("Inf") + # discard pruned non-top mentions using the same technique as + # above + all_indices = torch.tile(torch.arange(spans_nb), (b, 1)).to(device) + pruned_mask = ~torch.isin(all_indices, top_mentions_index) + assert pruned_mask.shape == (b, p) + dist = torch.tile(dist, (b, 1, 1)) + dist[pruned_mask, :] = float("Inf") # remove pruned lines + dist.swapaxes(1, 2)[pruned_mask, :] = float("Inf") # remove pruned cols + assert dist.shape == (b, p, p) + # top-k closest antecedents _, close_indexs = torch.topk(-dist, antecedents_nb) - assert close_indexs.shape == (spans_nb, antecedents_nb) + assert close_indexs.shape == (b, p, a) return close_indexs @@ -1455,9 +1475,8 @@ def forward( # antecedents for each spans antecedents_nb = a = min(self.config.antecedents_nb, spans_nb) antecedents_index = self.closest_antecedents_indexs( - spans_nb, words_nb, antecedents_nb + top_mentions_index, spans_nb, words_nb, antecedents_nb ) - antecedents_index = torch.tile(antecedents_index, (batch_size, 1, 1)) assert antecedents_index.shape == (b, p, a) # -- mention compatibility scores computation -- From d8f19f7924bc2b663d7cedaf4238dcb744d21646 Mon Sep 17 00:00:00 2001 From: Aethor Date: Tue, 17 Sep 2024 16:15:44 +0200 Subject: [PATCH 11/13] update README with latest benchmark --- README.md | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 71de55c..fda4b86 100644 --- a/README.md +++ b/README.md @@ -118,8 +118,6 @@ This results in: Even if the mentions `Princess Liana` and `She` are not in the same chunk, hierarchical merging still resolves this case correctly. -*Note that, at the time of writing, the performance of the hierarchical merging feature has not been benchmarked*. - ## Training a model @@ -174,24 +172,13 @@ Several work make use of additional features. 
For now, only the distance between # Results -The following table presents the results we obtained by training this model (for now, it has only one entry !). Note that: - -- the reported results use `max_span_size=5` instead of `max_span_size=10` as in training. -- the reported results were obtained by splitting documents for performance reasons, with subdocuments having a maximum length of 11 sentences. They may not be accurate with the performance on full documents. -- the reported results can not be directly compared to the performance in [the original Litbank paper](https://arxiv.org/abs/1912.01140) since we only compute performance on one split of the datas - -| Dataset | Base model | MUC | B3 | CEAF | CoNLL F1 | -|---------|-------------------|-------|-------|-------|----------| -| Litbank | `bert-base-cased` | 77.35 | 67.63 | 56.66 | 67.21 | - -## Results on full documents - -The following table reports our results on the full Litbank documents (~2000 tokens each). We use `max_span_size=10`. HM stand for "Hierarchical Merging": +The following table presents the results we obtained on Litbank by training this model. We evaluate on 10% of Litbank documents, each of which consists of ~2000 tokens. The *split* column indicates whether documents were split into blocks of 512 tokens. The *HM* column indicates whether we use hierarchical merging. 
-| Dataset | Base model | HM | MUC | B3 | CEAF | BLANC | LEA | -|---------|-------------------|-----|-------|-------|-------|-------|-------| -| Litbank | `bert-base-cased` | no | 72.97 | 48.26 | 46.64 | 47.16 | 27.33 | -| Litbank | `bert-base-cased` | yes | 72.29 | 51.73 | 46.36 | 55.67 | 35.14 | +| Dataset | Base model | split | HM | MUC | B3 | CEAF | BLANC | LEA | time (s) | +|---------|-------------------|-------|-----|-------|-------|-------|-------|-------|----------| +| Litbank | `bert-base-cased` | no | no | 75.03 | 60.66 | 48.71 | 62.96 | 32.84 | 22:07 | +| Litbank | `bert-base-cased` | yes | no | 73.84 | 49.14 | 47.88 | 48.41 | 27.63 | 16:18 | +| Litbank | `bert-base-cased` | yes | yes | 74.54 | 59.30 | 46.98 | 62.69 | 42.46 | 21:13 | # Citation From 97329fba5525cfc1ad29edc857deeb3b23b1af26 Mon Sep 17 00:00:00 2001 From: Aethor Date: Tue, 17 Sep 2024 16:16:03 +0200 Subject: [PATCH 12/13] update README benchmark columns --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index fda4b86..59b047f 100644 --- a/README.md +++ b/README.md @@ -174,11 +174,11 @@ Several work make use of additional features. For now, only the distance between The following table presents the results we obtained on Litbank by training this model. We evaluate on 10% of Litbank documents, each of which consists of ~2000 tokens. The *split* column indicates whether documents were split into blocks of 512 tokens. The *HM* column indicates whether we use hierarchical merging. 
-| Dataset | Base model | split | HM | MUC | B3 | CEAF | BLANC | LEA | time (s) | -|---------|-------------------|-------|-----|-------|-------|-------|-------|-------|----------| -| Litbank | `bert-base-cased` | no | no | 75.03 | 60.66 | 48.71 | 62.96 | 32.84 | 22:07 | -| Litbank | `bert-base-cased` | yes | no | 73.84 | 49.14 | 47.88 | 48.41 | 27.63 | 16:18 | -| Litbank | `bert-base-cased` | yes | yes | 74.54 | 59.30 | 46.98 | 62.69 | 42.46 | 21:13 | +| Dataset | Base model | split | HM | MUC | B3 | CEAF | BLANC | LEA | time (m:s) | +|---------|-------------------|-------|-----|-------|-------|-------|-------|-------|------------| +| Litbank | `bert-base-cased` | no | no | 75.03 | 60.66 | 48.71 | 62.96 | 32.84 | 22:07 | +| Litbank | `bert-base-cased` | yes | no | 73.84 | 49.14 | 47.88 | 48.41 | 27.63 | 16:18 | +| Litbank | `bert-base-cased` | yes | yes | 74.54 | 59.30 | 46.98 | 62.69 | 42.46 | 21:13 | # Citation From 148e5648d6a0cc0e3b85311023c21c396462b4cd Mon Sep 17 00:00:00 2001 From: Aethor Date: Tue, 17 Sep 2024 16:25:02 +0200 Subject: [PATCH 13/13] inhibit torch.index_reduce_ warning --- tibert/bertcoref.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index 2d2fc62..893c8eb 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -10,7 +10,7 @@ TypeVar, Union, ) -import re, glob, os +import re, glob, os, warnings from collections import defaultdict from pathlib import Path from dataclasses import dataclass @@ -332,9 +332,9 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]: # same length yet. 
return_tensors=None, ) - self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = ( - warning_state - ) + self.tokenizer.deprecation_warnings[ + "Asking-to-pad-a-fast-tokenizer" + ] = warning_state # keep encoding info batch._encodings = [f.encodings[0] for f in features] @@ -1349,9 +1349,11 @@ def wordreduce_embeddings( batch_word_ids = batch_word_ids[token_mask] words_nb = len(set(batch_word_ids.tolist())) words = torch.zeros(words_nb, h, device=device) - words.index_reduce_( - 0, batch_word_ids, batch_encoded, "mean", include_self=False - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + words.index_reduce_( + 0, batch_word_ids, batch_encoded, "mean", include_self=False + ) word_encoded.append(words) # each 2D-tensor in word_encoded is of shape (w_i, h). w_i