diff --git a/README.md b/README.md index 71de55c..59b047f 100644 --- a/README.md +++ b/README.md @@ -118,8 +118,6 @@ This results in: Even if the mentions `Princess Liana` and `She` are not in the same chunk, hierarchical merging still resolves this case correctly. -*Note that, at the time of writing, the performance of the hierarchical merging feature has not been benchmarked*. - ## Training a model @@ -174,24 +172,13 @@ Several work make use of additional features. For now, only the distance between # Results -The following table presents the results we obtained by training this model (for now, it has only one entry !). Note that: - -- the reported results use `max_span_size=5` instead of `max_span_size=10` as in training. -- the reported results were obtained by splitting documents for performance reasons, with subdocuments having a maximum length of 11 sentences. They may not be accurate with the performance on full documents. -- the reported results can not be directly compared to the performance in [the original Litbank paper](https://arxiv.org/abs/1912.01140) since we only compute performance on one split of the datas - -| Dataset | Base model | MUC | B3 | CEAF | CoNLL F1 | -|---------|-------------------|-------|-------|-------|----------| -| Litbank | `bert-base-cased` | 77.35 | 67.63 | 56.66 | 67.21 | - -## Results on full documents - -The following table reports our results on the full Litbank documents (~2000 tokens each). We use `max_span_size=10`. HM stand for "Hierarchical Merging": +The following table presents the results we obtained on Litbank by training this model. We evaluate on 10% of Litbank documents, each of which consists of ~2000 tokens. The *split* column indicates whether documents were split into blocks of 512 tokens. The *HM* column indicates whether we use hierarchical merging.
-| Dataset | Base model | HM | MUC | B3 | CEAF | BLANC | LEA | -|---------|-------------------|-----|-------|-------|-------|-------|-------| -| Litbank | `bert-base-cased` | no | 72.97 | 48.26 | 46.64 | 47.16 | 27.33 | -| Litbank | `bert-base-cased` | yes | 72.29 | 51.73 | 46.36 | 55.67 | 35.14 | +| Dataset | Base model | split | HM | MUC | B3 | CEAF | BLANC | LEA | time (m:s) | +|---------|-------------------|-------|-----|-------|-------|-------|-------|-------|------------| +| Litbank | `bert-base-cased` | no | no | 75.03 | 60.66 | 48.71 | 62.96 | 32.84 | 22:07 | +| Litbank | `bert-base-cased` | yes | no | 73.84 | 49.14 | 47.88 | 48.41 | 27.63 | 16:18 | +| Litbank | `bert-base-cased` | yes | yes | 74.54 | 59.30 | 46.98 | 62.69 | 42.46 | 21:13 | # Citation diff --git a/tibert/bertcoref.py b/tibert/bertcoref.py index 77bf456..893c8eb 100644 --- a/tibert/bertcoref.py +++ b/tibert/bertcoref.py @@ -10,18 +10,16 @@ TypeVar, Union, ) -import re, glob, os +import re, glob, os, warnings from collections import defaultdict from pathlib import Path from dataclasses import dataclass from sacremoses import MosesTokenizer from more_itertools.recipes import flatten -import numpy as np import torch from torch.nn.parameter import Parameter from torch.utils.data import Dataset -from torch.utils.data.dataloader import DataLoader -from transformers import PreTrainedModel, BertPreTrainedModel # type: ignore +from transformers import BertPreTrainedModel # type: ignore from transformers import PreTrainedTokenizerFast # type: ignore from transformers.file_utils import PaddingStrategy from transformers.data.data_collator import DataCollatorMixin @@ -157,7 +155,10 @@ def document_labels(self, max_span_size: int) -> Tuple[torch.Tensor, torch.Tenso return (self.coref_labels(max_span_size), self.mention_labels(max_span_size)) def prepared_document( - self, tokenizer: PreTrainedTokenizerFast, max_span_size: int + self, + tokenizer: PreTrainedTokenizerFast, + max_span_size: int, + mode: Literal["train", "test"] = "train", ) -> Tuple[CoreferenceDocument, BatchEncoding]: """Prepare a document for being inputted into a model. The document is retokenized thanks to ``tokenizer``, and @@ -169,6 +170,10 @@ def prepared_document( method. This means special tokens will be added. :param tokenizer: tokenizer used to retokenized the document + :param max_span_size: + :param mode: if 'test', dont include labels in the returned + :class:`BatchEncoding` + :return: a tuple : - a new :class:`CoreferenceDocument` @@ -222,64 +227,14 @@ def prepared_document( new_chains.append(new_chain) document = CoreferenceDocument(tokens, new_chains) - coref_labels, mention_labels = document.document_labels(max_span_size) - batch["coref_labels"] = coref_labels - batch["mention_labels"] = mention_labels + if mode == "train": + coref_labels, mention_labels = self.document_labels(max_span_size) + batch["coref_labels"] = coref_labels + batch["mention_labels"] = mention_labels + batch["word_ids"] = torch.tensor([-1 if i is None else i for i in words_ids]) return document, batch - def from_wpieced_to_tokenized( - self, tokens: List[str], wp_to_token: List[int] - ) -> CoreferenceDocument: - """Convert the current document, tokenized with wordpieces, to - a document 'normally' tokenized - - :param tokens: the original tokens - :param wp_to_token: mapping from wordpiece index to token index - - :return: - """ - # In some cases, output mentions can be produced several - # times. 
This can happen when a mention boundary is expressed - # by several wordpiece (for example the wordpiece mentions : - # "l ' abbé" and "l '" can both correspond to the regular - # mention "l 'abbé". For those cases, we keep only one choice - # and assume that predictions for these mentions are - # consistent. - already_visited_mentions = set() - - new_chains = [] - - for chain in self.coref_chains: - new_chain = [] - - for mention in chain: - new_start_idx = wp_to_token[mention.start_idx] - new_end_idx = wp_to_token[mention.end_idx - 1] - # NOTE: this happens in case the model has predicted - # an erroneous mention such as '[CLS]' or '[SEP]'. In - # that case, we simply ignore the mention. - if new_start_idx is None or new_end_idx is None: - continue - new_end_idx += 1 - - new_mention = Mention( - tokens[new_start_idx:new_end_idx], - new_start_idx, - new_end_idx, - ) - - if new_mention in already_visited_mentions: - continue - - new_chain.append(new_mention) - - already_visited_mentions.add(new_mention) - - new_chains.append(new_chain) - - return CoreferenceDocument(tokens, new_chains) - @staticmethod def from_labels( tokens: List[str], @@ -365,20 +320,6 @@ class DataCollatorForSpanClassification(DataCollatorMixin): return_tensors: Literal["pt"] = "pt" def torch_call(self, features) -> Union[dict, BatchEncoding]: - coref_labels = ( - [feature["coref_labels"] for feature in features] - if "coref_labels" in features[0].keys() - else None - ) - mention_labels = ( - [feature["mention_labels"] for feature in features] - if "mention_labels" in features[0].keys() - else None - ) - assert (coref_labels is None and mention_labels is None) or ( - coref_labels and mention_labels - ) - warning_state = self.tokenizer.deprecation_warnings.get( "Asking-to-pad-a-fast-tokenizer", False ) @@ -387,9 +328,9 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]: features, padding=self.padding, max_length=self.max_length, - # Conversion to tensors will fail if we have labels as - # they are not of the same length yet. - return_tensors="pt" if coref_labels is None else None, + # Conversion to tensors will fail as they are not of the + # same length yet. + return_tensors=None, ) self.tokenizer.deprecation_warnings[ "Asking-to-pad-a-fast-tokenizer" @@ -398,40 +339,48 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]: # keep encoding info batch._encodings = [f.encodings[0] for f in features] - if coref_labels is None: - return batch - - documents = [ - CoreferenceDocument.from_labels( - tokens, coref_labels, mention_labels, max_span_size=self.max_span_size - ) - for tokens, coref_labels, mention_labels in zip( - [f["input_ids"] for f in features], - [f["coref_labels"] for f in features], - [f["mention_labels"] for f in features], - ) - ] + device = torch.device(self.device) - for document, tokens in zip(documents, batch["input_ids"]): # type: ignore - document.tokens = tokens - labels = [doc.document_labels(self.max_span_size) for doc in documents] + if "coref_labels" in batch: + # TODO: if we can find a better way to pad sparse tensors, + # that would be great... At least the big tensor never finds + # its way onto the GPU. 
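+ # Each sample's coref_labels is a sparse tensor of shape (p_i, p_i + 1): one row
+ # per candidate span, plus a trailing column that must stay last. We densify,
+ # pad rows up to max_p, insert zero columns just before that trailing column so
+ # every sample ends up with shape (max_p, max_p + 1), then convert back to a
+ # sparse COO tensor before stacking.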
+ max_p = max([tens.shape[0] for tens in batch["coref_labels"]]) + for tens_i, tens in enumerate(batch["coref_labels"]): + p = tens.shape[0] + if max_p > p: + tens = torch.cat([tens.to_dense(), torch.zeros(max_p - p, p + 1)]) + batch["coref_labels"][tens_i] = torch.cat( + [ + tens[:, :-1], + torch.zeros(max_p, max_p - p), + tens[:, -1].unsqueeze(1), + ], + dim=1, + ).to_sparse_coo() + batch["coref_labels"] = torch.stack( + [tens for tens in batch["coref_labels"]] + ).to(device) + + batch["mention_labels"] = torch.nn.utils.rnn.pad_sequence( + batch["mention_labels"], batch_first=True + ).to(device) + + batch["word_ids"] = torch.nn.utils.rnn.pad_sequence( + batch["word_ids"], batch_first=True, padding_value=-1 + ).to(device) - device = torch.device(self.device) - del batch["coref_labels"] - del batch["mention_labels"] batch = BatchEncoding( { - k: torch.tensor(v, dtype=torch.int64, device=device) + k: ( + torch.tensor(v, dtype=torch.int64, device=device) + if not k in ("coref_labels", "mention_labels", "word_ids") + else v + ) for k, v in batch.items() }, encoding=batch.encodings, ) - batch["coref_labels"] = torch.stack( - [coref_labels for coref_labels, _ in labels] - ).to(device) - batch["mention_labels"] = torch.stack( - [mention_labels for _, mention_labels in labels] - ).to(device) return batch @@ -441,6 +390,7 @@ class CoreferenceDataset(Dataset): :ivar documents: :ivar tokenizer: :ivar max_span_len: + :ivar mode: """ def __init__( @@ -453,6 +403,13 @@ def __init__( self.documents = documents self.tokenizer = tokenizer self.max_span_size = max_span_size + self.mode: Literal["train", "test"] = "train" + + def set_test_(self): + self.mode = "test" + + def set_train_(self): + self.mode = "train" @staticmethod def from_conll2012_file( @@ -676,7 +633,9 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> BatchEncoding: document = self.documents[index] - _, batch = document.prepared_document(self.tokenizer, self.max_span_size) + _, batch = document.prepared_document( + self.tokenizer, self.max_span_size, self.mode + ) return batch @@ -850,23 +809,39 @@ def load_ontonotes_dataset( @dataclass class BertCoreferenceResolutionOutput: - # (batch_size, top_mentions_nb, antecedents_nb) + """Output of BertForCoreferenceResolution + + .. note :: + + We use the following short notation to annotate shapes : + + - b: batch_size + - s: seq_size (in wordpieces) + - w: seq_size_words (in words) + - p: spans_nb + - m: top_mentions_nb + - a: antecedents_nb + - h: hidden_size + - t: metadatas_features_size + """ + + # (b, m, a) logits: torch.Tensor - # (batch_size, top_mentions_nb) + # (b, m) top_mentions_index: torch.Tensor - # (batch_size, spans_nb) + # (b, p) mentions_scores: torch.Tensor - # (batch_size, top_mentions_nb, antecedents_nb) + # (b, m, a) top_antecedents_index: torch.Tensor max_span_size: int loss: Optional[torch.Tensor] = None - # (batch_size, seq_size, hidden_size) + # (b, w, h) hidden_states: Optional[torch.FloatTensor] = None def coreference_documents( @@ -875,7 +850,7 @@ def coreference_documents( """Extract a :class:`.CoreferenceDocument` list from a coreference output. - :param tokens: + :param tokens: the original tokens :return: a list of :class:`.CoreferenceDocument`, one per batch @@ -897,7 +872,14 @@ def coreference_documents( G = nx.Graph() for m_j in range(top_mentions_nb): span_i = int(self.top_mentions_index[b_i][m_j].item()) - span_coords = spans_idx[span_i] + # it is possible to have a top span that does not + # actually exist in a batch sample. 
This is because + # padding is done on wordpieces but not on words. In + # that case, we simply ignore that predicted span. + try: + span_coords = spans_idx[span_i] + except IndexError: + continue mention_score = float(self.mentions_scores[b_i][span_i].item()) span_mention = Mention( @@ -920,7 +902,10 @@ def coreference_documents( antecedent_span_i = int( self.top_antecedents_index[b_i][m_j][top_antecedent_idx].item() ) - antecedent_coords = spans_idx[antecedent_span_i] + try: + antecedent_coords = spans_idx[antecedent_span_i] + except IndexError: + continue antecedent_mention_score = float( self.mentions_scores[b_i][antecedent_span_i].item() @@ -977,7 +962,8 @@ class BertForCoreferenceResolution(BertPreTrainedModel): We use the following short notation to annotate shapes : - b: batch_size - - s: seq_size + - s: seq_size (in wordpieces) + - w: seq_size_words (in words) - p: spans_nb - m: top_mentions_nb - a: antecedents_nb @@ -1100,23 +1086,23 @@ def mention_compatibility_score( return self.mention_compatibility_scorer(score).squeeze(-1) def pruned_mentions_indexs( - self, mention_scores: torch.Tensor, seq_size: int, top_mentions_nb: int + self, mention_scores: torch.Tensor, words_nb: int, top_mentions_nb: int ) -> torch.Tensor: - """Prune mentions, keeping only the k non-overlapping best of them + """Prune mentions, keeping only the k non-crossing best of them The algorithm works as follows : 1. Sort mentions by individual scores - 2. Accept mention in orders, from best to worst score, until k of + 2. Accept mention in order, from best to worst score, until k of them are accepted. A mention can only be accepted if no - previously accepted span os overlapping with it. + previously accepted span is crossing with it. See section 5 of the E2ECoref paper and the C++ kernel in the E2ECoref repository. :param mention_scores: a tensor of shape ``(b, p)`` - :param seq_size: + :param words_nb: :param top_mentions_nb: the maximum number of spans to keep during the pruning process @@ -1128,14 +1114,14 @@ def pruned_mentions_indexs( assert top_mentions_nb <= spans_nb - spans_idx = spans_indexs(list(range(seq_size)), self.config.max_span_size) + spans_idx = spans_indexs(list(range(words_nb)), self.config.max_span_size) - def spans_are_overlapping( - span1: Tuple[int, int], span2: Tuple[int, int] - ) -> bool: - return ( - span1[0] < span2[0] and span2[0] <= span1[1] and span1[1] < span2[1] - ) or (span2[0] < span1[0] and span1[0] <= span2[1] and span2[1] < span1[1]) + def spans_are_crossing(span1: Tuple[int, int], span2: Tuple[int, int]) -> bool: + start1, end1 = (span1[0], span1[1] - 1) + start2, end2 = (span2[0], span2[1] - 1) + return (start1 < start2 and start2 <= end1 and end1 < end2) or ( + start2 < start1 and start1 <= end2 and end2 < end1 + ) _, sorted_indexs = torch.sort(mention_scores, 1, descending=True) # TODO: what if we can't have top_mentions_nb mentions ?? 
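As an aside (not part of the patch), here is a standalone sketch of the `spans_are_crossing` predicate introduced above: unlike the previous overlap test, it only rejects spans that *cross* (partially overlap), so nested spans can both be accepted during pruning. Span ends are exclusive; the example spans are toy values.

```python
from typing import Tuple

def spans_are_crossing(span1: Tuple[int, int], span2: Tuple[int, int]) -> bool:
    # ends are exclusive; convert to inclusive before comparing
    start1, end1 = (span1[0], span1[1] - 1)
    start2, end2 = (span2[0], span2[1] - 1)
    return (start1 < start2 and start2 <= end1 and end1 < end2) or (
        start2 < start1 and start1 <= end2 and end2 < end1
    )

assert spans_are_crossing((0, 3), (2, 5))      # partial overlap: one of them is pruned
assert not spans_are_crossing((0, 5), (1, 3))  # nested: both may be kept
assert not spans_are_crossing((0, 2), (3, 5))  # disjoint: both may be kept
```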
@@ -1150,7 +1136,7 @@ def spans_are_overlapping( span_index = int(sorted_indexs[b_i][s_j].item()) if not any( [ - spans_are_overlapping( + spans_are_crossing( spans_idx[span_index], spans_idx[mention_idx] ) for mention_idx in mention_indexs[-1] @@ -1168,11 +1154,11 @@ def spans_are_overlapping( return mention_indexs - def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor: + def distance_between_spans(self, spans_nb: int, words_nb: int) -> torch.Tensor: """Compute the indexs of the k closest mentions :param spans_nb: number of spans in the sequence - :param seq_size: size of the sequence + :param words_nb: number of words in the sequence :return: a tensor of shape ``(p, p)`` """ p = spans_nb @@ -1180,7 +1166,7 @@ def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor: # a list of spans indices # [(start, end), ..., (start, end)] - spans_idx = spans_indexs(list(range(seq_size)), self.config.max_span_size) + spans_idx = spans_indexs(list(range(words_nb)), self.config.max_span_size) # (spans_nb,) start_idx = torch.tensor([start for start, _ in spans_idx]).to(device) @@ -1200,27 +1186,47 @@ def distance_between_spans(self, spans_nb: int, seq_size: int) -> torch.Tensor: return dist def closest_antecedents_indexs( - self, spans_nb: int, seq_size: int, antecedents_nb: int + self, + top_mentions_index: torch.Tensor, + spans_nb: int, + words_nb: int, + antecedents_nb: int, ): """Compute the indexs of the k closest mentions + :param top_mentions_index: a tensor of shape ``(b, m)`` :param spans_nb: number of spans in the sequence - :param seq_size: size of the sequence + :param words_nb: number of words in the sequence :param antecedents_nb: number of antecedents to consider - :return: a tensor of shape ``(p, a)`` + :return: a tensor of shape ``(b, p, a)`` """ - dist = self.distance_between_spans(spans_nb, seq_size) - assert dist.shape == (spans_nb, spans_nb) + device = next(self.parameters()).device + b, _ = top_mentions_index.shape + p = spans_nb + a = antecedents_nb + + dist = self.distance_between_spans(spans_nb, words_nb) + assert dist.shape == (p, p) # when the distance between a span and a possible antecedent - # is 0 or negative, it means the possible antecedents is after - # the span. Therefore, it can't be an antecedents. We set - # those distances to Inf for torch.topk usage just after + # is 0 or negative, it means the possible antecedent is after + # the span. Therefore, it can't be an antecedent. 
We set those + # distances to Inf for torch.topk usage just after dist[dist <= 0] = float("Inf") + # discard pruned non-top mentions using the same technique as + # above + all_indices = torch.tile(torch.arange(spans_nb), (b, 1)).to(device) + pruned_mask = ~torch.isin(all_indices, top_mentions_index) + assert pruned_mask.shape == (b, p) + dist = torch.tile(dist, (b, 1, 1)) + dist[pruned_mask, :] = float("Inf") # remove pruned lines + dist.swapaxes(1, 2)[pruned_mask, :] = float("Inf") # remove pruned cols + assert dist.shape == (b, p, p) + # top-k closest antecedents _, close_indexs = torch.topk(-dist, antecedents_nb) - assert close_indexs.shape == (spans_nb, antecedents_nb) + assert close_indexs.shape == (b, p, a) return close_indexs @@ -1229,24 +1235,24 @@ def distance_feature( top_antecedents_index: torch.Tensor, top_mentions_index: torch.Tensor, spans_nb: int, - seq_size: int, + words_nb: int, ) -> torch.Tensor: """Compute the distance feature between two spans :param top_antecedents_index: ``(b, m, a)`` :param top_mentions_index: ``(b, m)`` :param spans_nb: - :param seq_size: + :param words_nb: number of words in the sequence :return: ``(b, m, a, t)`` """ b, m, a = top_antecedents_index.shape t = self.config.metadatas_features_size p = spans_nb - s = seq_size + w = words_nb device = next(self.parameters()).device - dist = self.distance_between_spans(p, s) + dist = self.distance_between_spans(p, w) dist = dist.unsqueeze(0).repeat(b, 1, 1) assert dist.shape == (b, p, p) @@ -1287,7 +1293,6 @@ def bert_encode( :return: hidden states of the last layer, of shape ``(b, s, h)`` """ - # list[(batch_size, <= segment_size, hidden_size)] last_hidden_states = [] @@ -1312,6 +1317,50 @@ def maybe_take_segment( return torch.cat(last_hidden_states, dim=1) + def wordreduce_embeddings( + self, encoded: torch.Tensor, word_ids: torch.LongTensor + ) -> torch.Tensor: + """ + :param encoded: an embedding tensor of shape ``(b, s, h)`` + :param word_ids: a tensor of shape ``(b, s)`` + :return: an embedding tensor of shape ``(b, w, h)`` + """ + device = next(self.parameters()).device + b, _, h = encoded.shape + + word_encoded = [] + for b_i in range(b): + # here is what we want to achieve: + # + # word_ids ~ [0 0 0 -1 1 1 2 2 ] + # encoded ~ [A B C D E F G H ] + # => [mean(ABC) mean(EF) mean(GH)] + # + # To do so we first filter encoded to remove -1s since + # they dont correspond to any tokens. Then, we use + # torch.index_reduce to achieve mean reduction. At the + # time of this writing, torch.index_reduce is in beta so + # it will print out a warning! Unfortunately there is no + # batched version of torch.index_reduce as far as I + # understand, so we do it in a loop :< + batch_word_ids = word_ids[b_i] + token_mask = batch_word_ids != -1 + batch_encoded = encoded[b_i][token_mask] + batch_word_ids = batch_word_ids[token_mask] + words_nb = len(set(batch_word_ids.tolist())) + words = torch.zeros(words_nb, h, device=device) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + words.index_reduce_( + 0, batch_word_ids, batch_encoded, "mean", include_self=False + ) + word_encoded.append(words) + + # each 2D-tensor in word_encoded is of shape (w_i, h). w_i + # varies depending on the batch, so we have to do some + # padding! 
+ return torch.nn.utils.rnn.pad_sequence(word_encoded, batch_first=True) + def mention_loss( self, top_mention_scores: torch.Tensor, mention_labels: torch.Tensor ) -> torch.Tensor: @@ -1352,6 +1401,7 @@ def coref_loss( def forward( self, input_ids: torch.Tensor, + word_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -1362,6 +1412,7 @@ def forward( ) -> BertCoreferenceResolutionOutput: """ :param input_ids: a tensor of shape ``(b, s)`` + :param word_ids: a tensor of shape ``(b, s)`` :param attention_mask: a tensor of shape ``(b, s)`` :param token_type_ids: a tensor of shape ``(b, s)`` :param position_ids: a tensor of shape ``(b, s)`` @@ -1385,11 +1436,14 @@ def forward( head_mask=head_mask, ) assert encoded_input.shape == (b, s, h) + encoded_input = self.wordreduce_embeddings(encoded_input, word_ids) + words_nb = w = encoded_input.shape[1] + assert encoded_input.shape == (b, w, h) # -- span bounds computation -- # we select starting and ending bounds of spans of length up # to self.max_span_size - spans_idx = spans(range(seq_size), self.config.max_span_size) + spans_idx = spans(range(words_nb), self.config.max_span_size) spans_nb = p = len(spans_idx) spans_selector = torch.flatten( torch.tensor([[span[0], span[-1]] for span in spans_idx], dtype=torch.long) @@ -1409,9 +1463,11 @@ def forward( # top_mentions_index is the index of the m best # non-overlapping mentions - top_mentions_nb = m = int(self.config.mentions_per_tokens * seq_size) + top_mentions_nb = m = max( + min(1, words_nb), int(self.config.mentions_per_tokens * words_nb) + ) top_mentions_index = self.pruned_mentions_indexs( - mention_scores, seq_size, top_mentions_nb + mention_scores, words_nb, top_mentions_nb ) # TODO: hack top_mentions_nb = m = int(top_mentions_index.shape[1]) @@ -1421,9 +1477,8 @@ def forward( # antecedents for each spans antecedents_nb = a = min(self.config.antecedents_nb, spans_nb) antecedents_index = self.closest_antecedents_indexs( - spans_nb, seq_size, antecedents_nb + top_mentions_index, spans_nb, words_nb, antecedents_nb ) - antecedents_index = torch.tile(antecedents_index, (batch_size, 1, 1)) assert antecedents_index.shape == (b, p, a) # -- mention compatibility scores computation -- @@ -1449,7 +1504,7 @@ def forward( # distance feature computation dist_ft = self.distance_feature( - top_antecedents_index, top_mentions_index, spans_nb, seq_size + top_antecedents_index, top_mentions_index, spans_nb, words_nb ) dist_ft = torch.flatten(dist_ft, start_dim=1, end_dim=2) assert dist_ft.shape == (b, m * a, t) @@ -1502,7 +1557,7 @@ def forward( # NOTE: we have to rely on such a loop, as torch.gather # cannot be used on sparse tensors, which prevents using - # batch_index_select + # batch_index_select. 
selected_coref_labels = torch.stack( [ torch.index_select(coref_labels[b_i], 0, top_mentions_index[b_i]) diff --git a/tibert/predict.py b/tibert/predict.py index f59765a..34a3ec8 100644 --- a/tibert/predict.py +++ b/tibert/predict.py @@ -205,7 +205,7 @@ def tensorize_chains( return CoreferenceDocument(merged_left.tokens + merged_right.tokens, new_chains) -def _stream_predict_wpieced_coref_raw( +def _stream_predict_coref_raw( documents: List[Union[str, List[str]]], model: BertForCoreferenceResolution, tokenizer: PreTrainedTokenizerFast, @@ -215,14 +215,7 @@ def _stream_predict_wpieced_coref_raw( lang: str = "en", return_hidden_state: bool = False, ) -> Generator[ - Tuple[ - List[CoreferenceDocument], - List[CoreferenceDocument], - BatchEncoding, - BertCoreferenceResolutionOutput, - ], - None, - None, + Tuple[List[CoreferenceDocument], BertCoreferenceResolutionOutput], None, None ]: """Low level inference interface.""" @@ -248,6 +241,7 @@ def _stream_predict_wpieced_coref_raw( tokenizer, model.config.max_span_size, ) + dataset.set_test_() data_collator = DataCollatorForSpanClassification( tokenizer, model.config.max_span_size, device_str ) @@ -271,14 +265,9 @@ def _stream_predict_wpieced_coref_raw( **batch, return_hidden_state=return_hidden_state ) - out_docs = out.coreference_documents( - [ - [tokenizer.decode(t) for t in input_ids] # type: ignore - for input_ids in batch["input_ids"] - ] - ) + out_docs = out.coreference_documents([doc.tokens for doc in batch_docs]) - yield batch_docs, out_docs, batch, out + yield out_docs, out def stream_predict_coref( @@ -303,16 +292,11 @@ def stream_predict_coref( :return: a list of ``CoreferenceDocument``, with annotated coreference chains. """ - for original_docs, out_docs, batch, out in _stream_predict_wpieced_coref_raw( + for out_docs, _ in _stream_predict_coref_raw( documents, model, tokenizer, batch_size, quiet, device_str, lang ): - for batch_i, (original_doc, out_doc) in enumerate(zip(original_docs, out_docs)): - seq_size = batch["input_ids"].shape[1] - wp_to_token = [ - batch.token_to_word(batch_i, token_index=i) for i in range(seq_size) - ] - doc = out_doc.from_wpieced_to_tokenized(original_doc.tokens, wp_to_token) - yield doc + for out_doc in out_docs: + yield out_doc def predict_coref( @@ -344,13 +328,11 @@ def predict_coref( if hierarchical_merging: docs = [] hidden_states = [] - all_tokens = [] - wp_to_token = [] if len(documents) == 0: return None - for original_docs, out_docs, batch, out in _stream_predict_wpieced_coref_raw( + for out_docs, out in _stream_predict_coref_raw( documents, model, tokenizer, @@ -365,28 +347,8 @@ def predict_coref( assert not out.hidden_states is None hidden_states += [h for h in out.hidden_states] - all_tokens += list(flatten(doc.tokens for doc in original_docs)) - - batch_size = batch["input_ids"].shape[0] - seq_size = batch["input_ids"].shape[1] - for batch_i in range(batch_size): - # we need to shift the index of tokens in the batch by - # the index of the last token of the previous batch - max_prev_wp_to_token = ( - 0 - if len(wp_to_token) == 0 - else max([wtt for wtt in wp_to_token if not wtt is None], default=0) - ) - for i in range(seq_size): - wtt = batch.token_to_word(batch_i, token_index=i) - if not wtt is None and max_prev_wp_to_token != 0: - wtt += max_prev_wp_to_token + 1 - wp_to_token.append(wtt) - - merged_doc_wpieced = merge_coref_outputs(docs, hidden_states, model, device_str) - assert not merged_doc_wpieced is None # we know that len(docs) > 0 - - return 
merged_doc_wpieced.from_wpieced_to_tokenized(all_tokens, wp_to_token) + merged_doc = merge_coref_outputs(docs, hidden_states, model, device_str) + return merged_doc return list( stream_predict_coref( diff --git a/tibert/run_test.py b/tibert/run_test.py index d888d7b..0d92ef7 100644 --- a/tibert/run_test.py +++ b/tibert/run_test.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional import os import functools as ft from transformers import BertTokenizerFast, CamembertTokenizerFast # type: ignore @@ -6,6 +6,7 @@ from sacred.experiment import Experiment from sacred.run import Run from sacred.commands import print_config +from tibert import predict from tibert.bertcoref import ( CoreferenceDataset, CoreferenceDocument, @@ -15,7 +16,7 @@ BertForCoreferenceResolution, CamembertForCoreferenceResolution, ) -from tibert.score import score_coref_predictions +from tibert.score import score_coref_predictions, score_mention_detection from tibert.predict import predict_coref from tibert.utils import split_coreference_document_tokens @@ -29,6 +30,8 @@ def config(): dataset_name: str = "litbank" dataset_path: str = os.path.expanduser("~/litbank") max_span_size: int = 10 + # in tokens + limit_doc_size: Optional[int] = None hierarchical_merging: bool = False device_str: str = "auto" model_path: str @@ -41,6 +44,7 @@ def main( dataset_name: Literal["litbank", "fr-litbank", "democrat"], dataset_path: str, max_span_size: int, + limit_doc_size: Optional[int], hierarchical_merging: bool, device_str: Literal["cuda", "cpu", "auto"], model_path: str, @@ -79,36 +83,56 @@ def main( ) _, test_dataset = dataset.splitted(0.9) - all_annotated_docs = [] - for document in tqdm(test_dataset.documents): - doc_dataset = CoreferenceDataset( - split_coreference_document_tokens(document, 512), + if limit_doc_size is None: + all_annotated_docs = predict_coref( + [doc.tokens for doc in test_dataset.documents], + model, tokenizer, - max_span_size, + device_str=device_str, + batch_size=batch_size, ) - if hierarchical_merging: - annotated_doc = predict_coref( - [doc.tokens for doc in doc_dataset.documents], - model, + assert isinstance(all_annotated_docs, list) + else: + all_annotated_docs = [] + for document in tqdm(test_dataset.documents): + doc_dataset = CoreferenceDataset( + split_coreference_document_tokens(document, limit_doc_size), tokenizer, - hierarchical_merging=True, - quiet=True, - device_str=device_str, - batch_size=batch_size, + max_span_size, ) - else: - annotated_docs = predict_coref( - [doc.tokens for doc in doc_dataset.documents], - model, - tokenizer, - hierarchical_merging=False, - quiet=True, - device_str=device_str, - batch_size=batch_size, - ) - assert isinstance(annotated_docs, list) - annotated_doc = CoreferenceDocument.concatenated(annotated_docs) - all_annotated_docs.append(annotated_doc) + if hierarchical_merging: + annotated_doc = predict_coref( + [doc.tokens for doc in doc_dataset.documents], + model, + tokenizer, + hierarchical_merging=True, + quiet=True, + device_str=device_str, + batch_size=batch_size, + ) + else: + annotated_docs = predict_coref( + [doc.tokens for doc in doc_dataset.documents], + model, + tokenizer, + quiet=True, + device_str=device_str, + batch_size=batch_size, + ) + assert isinstance(annotated_docs, list) + annotated_doc = CoreferenceDocument.concatenated(annotated_docs) + all_annotated_docs.append(annotated_doc) + + mention_pre, mention_rec, mention_f1 = score_mention_detection( + all_annotated_docs, test_dataset.documents + ) + for metric_key, score 
in [ + ("precision", mention_pre), + ("recall", mention_rec), + ("f1", mention_f1), + ]: + print(f"mention.{metric_key}={score}") + _run.log_scalar(f"mention.{metric_key}", score) scores = score_coref_predictions(all_annotated_docs, test_dataset.documents) for key, score_dict in scores.items(): diff --git a/tibert/run_train.py b/tibert/run_train.py index 10c1e0c..d6ea3a3 100644 --- a/tibert/run_train.py +++ b/tibert/run_train.py @@ -14,6 +14,7 @@ load_train_checkpoint, predict_coref, score_coref_predictions, + score_mention_detection, ) from tibert.bertcoref import CoreferenceDataset, load_democrat_dataset @@ -145,6 +146,17 @@ def main( ) assert isinstance(annotated_docs, list) + mention_pre, mention_rec, mention_f1 = score_mention_detection( + annotated_docs, test_dataset.documents + ) + for metric_key, score in [ + ("precision", mention_pre), + ("recall", mention_rec), + ("f1", mention_f1), + ]: + print(f"mention.{metric_key}={score}") + _run.log_scalar(f"mention.{metric_key}", score) + metrics = score_coref_predictions(annotated_docs, test_dataset.documents) print(metrics) for key, score_dict in metrics.items():
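As a reference for the word-level refactor above, here is a minimal standalone sketch (not part of the patch) of the wordpiece-to-word mean pooling performed by `wordreduce_embeddings`; the tensors are toy values and the hidden size is set to 1 for readability.

```python
import torch

# word_ids maps each wordpiece to its word index; -1 marks special tokens.
word_ids = torch.tensor([0, 0, 0, -1, 1, 1, 2, 2])
encoded = torch.arange(8, dtype=torch.float).unsqueeze(1)  # (s, h) with h = 1

# drop wordpieces that do not belong to any word
mask = word_ids != -1
kept_ids = word_ids[mask]
kept_enc = encoded[mask]

words_nb = len(set(kept_ids.tolist()))
words = torch.zeros(words_nb, encoded.shape[1])
# mean-reduce each word's wordpiece embeddings into its slot
# (index_reduce_ may emit a beta-feature warning, as noted in the patch)
words.index_reduce_(0, kept_ids, kept_enc, "mean", include_self=False)

print(words.squeeze(1))  # tensor([1.0000, 4.5000, 6.5000])
```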