From 7a5feb6a466ee8b3a0cf2aa6cd583c58c7363fde Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Wed, 18 Jun 2025 21:22:13 +0800 Subject: [PATCH 1/9] corefud 1.3 corpus support - new language groups for coref 1.3 - handling of underscore forms --- stanza/models/common/doc.py | 11 +++ .../utils/datasets/coref/convert_udcoref.py | 85 +++++++++++++------ 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py index 9e0a0746a5..0ffb62bd9c 100644 --- a/stanza/models/common/doc.py +++ b/stanza/models/common/doc.py @@ -740,6 +740,17 @@ def empty_words(self, value): """ Set the list of words for this sentence. """ self._empty_words = value + @property + def all_words(self): + """ Access the list of words + empty words for this sentence. """ + words = self._words + empty_words = self._empty_words + + all = sorted(words + empty_words, key=lambda x:(x.id,) + if isinstance(x.id, int) else x.id) + + return all + @property def ents(self): """ Access the list of entities in this sentence. """ diff --git a/stanza/utils/datasets/coref/convert_udcoref.py b/stanza/utils/datasets/coref/convert_udcoref.py index 72b0e8c18a..8ef3186628 100644 --- a/stanza/utils/datasets/coref/convert_udcoref.py +++ b/stanza/utils/datasets/coref/convert_udcoref.py @@ -35,9 +35,9 @@ def process_documents(docs, augment=False): # extract the entities # get sentence words and lengths - sentences = [[j.text for j in i.words] + sentences = [[j.text for j in i.all_words] for i in doc.sentences] - sentence_lens = [len(x.words) for x in doc.sentences] + sentence_lens = [len(x.all_words) for x in doc.sentences] cased_words = [] for x in sentences: @@ -56,13 +56,13 @@ def process_documents(docs, augment=False): # TODO: does SD vs UD matter? 
deprel = [] for sentence in doc.sentences: - for word in sentence.words: + for word in sentence.all_words: deprel.append(word.deprel) - if word.head == 0: + if not word.head or word.head == 0: heads.append("null") else: heads.append(word.head - 1 + word_total) - word_total += len(sentence.words) + word_total += len(sentence.all_words) span_clusters = defaultdict(list) word_clusters = defaultdict(list) @@ -75,7 +75,7 @@ def process_documents(docs, augment=False): misc = [[k.split("=") for k in j if k.split("=")[0] == "Entity"] - for i in parsed_sentence.words + for i in parsed_sentence.all_words for j in [i.misc.split("|") if i.misc else []]] # and extract the Entity entry values entities = [i[0][1] if len(i) > 0 else None for i in misc] @@ -112,34 +112,43 @@ def process_documents(docs, augment=False): for k, v in final_refs.items(): for i in v: coref_spans.append([int(k), i[0], i[1]]) - sentence_upos = [x.upos for x in parsed_sentence.words] - sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words] + sentence_upos = [x.upos for x in parsed_sentence.all_words] + sentence_heads = [x.head - 1 if x.head and x.head > 0 else None for x in parsed_sentence.all_words] + for span in coref_spans: # input is expected to be start word, end word + 1 # counting from 0 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive span_start = span[1] + word_total span_end = span[2] + word_total + 1 - candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) - if candidate_head is None: - for candidate_head in range(span[1], span[2] + 1): - # stanza uses 0 to mark the head, whereas OntoNotes is counting - # words from 0, so we have to subtract 1 from the stanza heads - #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) - # treat the head of the phrase as the first word that has a head outside the phrase - if (parsed_sentence.words[candidate_head].head - 1 < span[1] or - parsed_sentence.words[candidate_head].head - 1 > span[2]): - break - else: - # if none have a head outside the phrase (circular??) - # then just take the first word - candidate_head = span[1] + # if its a zero coref (i.e. coref, but the head in None), we call + # the beginning of the span (i.e. the zero itself) the head + + if not sentence_heads[span_start-1]: + candidate_head = span[1] + else: + candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) + if candidate_head is None: + for candidate_head in range(span[1], span[2] + 1): + # stanza uses 0 to mark the head, whereas OntoNotes is counting + # words from 0, so we have to subtract 1 from the stanza heads + #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) + # treat the head of the phrase as the first word that has a head outside the phrase + if parsed_sentence.all_words[candidate_head].head and ( + parsed_sentence.all_words[candidate_head].head - 1 < span[1] or + parsed_sentence.all_words[candidate_head].head - 1 > span[2] + ): + break + else: + # if none have a head outside the phrase (circular??) 
+ # then just take the first word + candidate_head = span[1] #print("----> %d" % candidate_head) candidate_head += word_total span_clusters[span[0]].append((span_start, span_end)) word_clusters[span[0]].append(candidate_head) head2span.append((candidate_head, span_start, span_end)) - word_total += len(parsed_sentence.words) + word_total += len(parsed_sentence.all_words) span_clusters = sorted([sorted(values) for _, values in span_clusters.items()]) word_clusters = sorted([sorted(values) for _, values in word_clusters.items()]) head2span = sorted(head2span) @@ -172,7 +181,7 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_ for load in filenames: lang = load.split("/")[-1].split("_")[0] print("Ingesting %s from %s of lang %s" % (section, load, lang)) - docs = CoNLL.conll2multi_docs(load) + docs = CoNLL.conll2multi_docs(load, ignore_gapping=False) print(" Ingested %d documents" % len(docs)) if split_test and section == 'train': test_section = [] @@ -216,7 +225,7 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_ json.dump(converted_section, fout, indent=2) def get_dataset_by_language(coref_input_path, langs): - conll_path = os.path.join(coref_input_path, "CorefUD-1.2-public", "data") + conll_path = os.path.join(coref_input_path, "CorefUD-1.3-public", "data") train_filenames = [] dev_filenames = [] for lang in langs: @@ -242,9 +251,9 @@ def main(): coref_output_path = paths['COREF_DATA_DIR'] if args.project: - if args.project == 'slavic': - project = "slavic_udcoref" - langs = ('Polish', 'Russian', 'Czech') + if args.project == 'baltoslavic': + project = "baltoslavic_udcoref" + langs = ('Polish', 'Russian', 'Czech', 'Old_Church_Slavonic', 'Lithuanian') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'hungarian': project = "hu_udcoref" @@ -262,6 +271,26 @@ def main(): project = "norwegian_udcoref" langs = ('Norwegian',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'turkish': + project = "turkish_udcoref" + langs = ('Turkish',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'korean': + project = "korean_udcoref" + langs = ('Korean',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'hindi': + project = "hindi_udcoref" + langs = ('Hindi',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'ancient_greek': + project = "ancient_greek_udcoref" + langs = ('Ancient_Greek',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'ancient_hebrew': + project = "ancient_hebrew_udcoref" + langs = ('Ancient_Hebrew',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) else: project = args.directory conll_path = os.path.join(coref_input_path, project) From 59bce53454f689094e3292f0c2b26537bc6ad0b5 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Fri, 20 Jun 2025 13:16:02 +0800 Subject: [PATCH 2/9] model changes to support underscore innference --- stanza/models/coref/const.py | 2 + stanza/models/coref/dataset.py | 16 +++ stanza/models/coref/model.py | 68 ++++++++++-- stanza/models/coref/utils.py | 55 ++++++++++ .../utils/datasets/coref/convert_udcoref.py | 103 ++++++++++++++---- 5 files changed, 215 insertions(+), 29 deletions(-) diff --git a/stanza/models/coref/const.py 
b/stanza/models/coref/const.py index 931eee1229..479ee5feea 100644 --- a/stanza/models/coref/const.py +++ b/stanza/models/coref/const.py @@ -25,3 +25,5 @@ class CorefResult: rough_scores: torch.Tensor = None # [n_words, n_words] span_scores: torch.Tensor = None # [n_heads, n_words, 2] span_y: Tuple[torch.Tensor, torch.Tensor] = None # [n_heads] x2 + + zero_scores: torch.Tensor = None diff --git a/stanza/models/coref/dataset.py b/stanza/models/coref/dataset.py index fca7d4e500..7fd81ceaef 100644 --- a/stanza/models/coref/dataset.py +++ b/stanza/models/coref/dataset.py @@ -38,6 +38,9 @@ def __init__(self, path, config, tokenizer): word2subword = [] subwords = [] word_id = [] + nonblank_subwords = [] # a list of subwords, skipping _ + previous_was_blank = [] # was the word before _? + was_blank = False # a flag to set if we saw "_" for i, word in enumerate(doc["cased_words"]): tokenized = self.tokenizer.tokenize(word) if len(tokenized) == 0: @@ -50,9 +53,22 @@ def __init__(self, path, config, tokenizer): word2subword.append((len(subwords), len(subwords) + len(tokenized_word))) subwords.extend(tokenized_word) word_id.extend([i] * len(tokenized_word)) + if word == "_": + was_blank = True + else: + nonblank_subwords.extend(tokenized_word) + previous_was_blank.extend( + [True if was_blank else False]+[False]*(len(tokenized_word)-1) + ) + was_blank = False + + doc["nonblank_subwords"] = nonblank_subwords + doc["blank_prefix"] = previous_was_blank + doc["word2subword"] = word2subword doc["subwords"] = subwords doc["word_id"] = word_id + self.__out.append(doc) logger.info("Loaded %d docs from %s.", len(data_f), path) diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index 69d6a8e102..384135eb45 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -33,6 +33,7 @@ from stanza.models.coref.rough_scorer import RoughScorer from stanza.models.coref.span_predictor import SpanPredictor from stanza.models.coref.utils import GraphNode +from stanza.models.coref.utils import sigmoid_focal_loss from stanza.models.coref.word_encoder import WordEncoder from stanza.models.coref.dataset import CorefDataset from stanza.models.coref.tokenizer_customization import * @@ -41,6 +42,8 @@ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper +import torch.nn as nn + logger = logging.getLogger('stanza') class CorefModel: # pylint: disable=too-many-instance-attributes @@ -140,6 +143,8 @@ def evaluate(self, running_loss = 0.0 s_correct = 0 s_total = 0 + z_correct = 0 + z_total = 0 with conll.open_(self.config, self.epochs_trained, data_split) \ as (gold_f, pred_f): @@ -150,13 +155,21 @@ def evaluate(self, # want to test evaluation on one language continue - res = self.run(doc) + res = self.run(doc, True) + # measure zero prediction accuracy + zero_targets = torch.tensor(doc["is_zero"], device=res.zero_scores.device) + zero_preds = (res.zero_scores > 0).view(-1).to(zero_targets.dtype) + z_correct += (zero_preds == zero_targets).sum().item() + z_total += zero_targets.numel() if (res.coref_y.argmax(dim=1) == 1).all(): logger.warning(f"EVAL: skipping document with no corefs...") continue running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item() + if res.word_clusters is None or res.span_clusters is None: + logger.warning(f"EVAL: skipping document with no clusters...") + continue if res.span_y: pred_starts = res.span_scores[:, 
:, 0].argmax(dim=1) @@ -191,8 +204,10 @@ def evaluate(self, f" f1: {s_lea[0]:.5f}," f" p: {s_lea[1]:.5f}," f" r: {s_lea[2]:<.5f}" + f" | ZA: {z_correct / z_total:<.5f}" ) logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}") + logger.info(f"Zero prediction accuracy: {z_correct / z_total:.5f}") return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff) @@ -332,6 +347,7 @@ def load_model(path: str, def run(self, # pylint: disable=too-many-locals doc: Doc, + use_gold_spans_for_zeros = False ) -> CorefResult: """ This is a massive method, but it made sense to me to not split it into @@ -380,16 +396,27 @@ def run(self, # pylint: disable=too-many-locals res.coref_y = self._get_ground_truth( cluster_ids, top_indices, (top_rough_scores > float("-inf")), self.config.clusters_starts_are_singletons, - self.config.singletons) + self.config.singletons + ) - res.word_clusters = self._clusterize(doc, res.coref_scores, top_indices, - self.config.singletons) + res.word_clusters = self._clusterize( + doc, res.coref_scores, top_indices, + self.config.singletons + ) res.span_scores, res.span_y = self.sp.get_training_data(doc, words) if not self.training: res.span_clusters = self.sp.predict(doc, words, res.word_clusters) + if not self.training and not use_gold_spans_for_zeros: + zero_words = words[[word_id + for cluster in res.word_clusters + for word_id in cluster]] + else: + zero_words = words[[i[0] for i in sorted(doc["head2span"])]] + res.zero_scores = self.zeros_predictor(zero_words) + return res def save_weights(self, save_path=None, save_optimizers=True): @@ -454,6 +481,7 @@ def train(self, log=False): self.log_norms() running_c_loss = 0.0 running_s_loss = 0.0 + running_z_loss = 0.0 random.shuffle(docs_ids) pbar = tqdm(docs_ids, unit="docs", ncols=0) for doc_indx, doc_id in enumerate(pbar): @@ -468,6 +496,14 @@ def train(self, log=False): res = self.run(doc) + if res.zero_scores.size(0) == 0: + z_loss = 0.0 # since there are no corefs + else: + z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), + (torch.tensor(doc["is_zero"]) + .to(res.zero_scores.device).float()), + reduction="mean") + c_loss = self._coref_criterion(res.coref_scores, res.coref_y) if res.span_y: @@ -478,18 +514,25 @@ def train(self, log=False): del res - (c_loss + s_loss).backward() + (c_loss + s_loss + z_loss).backward() running_c_loss += c_loss.item() running_s_loss += s_loss.item() + if z_loss: + running_z_loss += z_loss.item() # log every 100 docs if log and doc_indx % 100 == 0: - wandb.log({'train_c_loss': c_loss.item(), - 'train_s_loss': s_loss.item()}) + logged = { + 'train_c_loss': c_loss.item(), + 'train_s_loss': s_loss.item(), + } + if z_loss: + logged['train_z_loss'] = z_loss.item() + wandb.log(logged) - del c_loss, s_loss + del c_loss, s_loss, z_loss for optim in self.optimizers.values(): optim.step() @@ -501,6 +544,7 @@ def train(self, log=False): f" {doc['document_id']:26}" f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}" f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}" + f" z_loss: {running_z_loss / (pbar.n + 1):<.5f}" ) self.epochs_trained += 1 @@ -614,12 +658,17 @@ def _build_model(self, foundation_cache): self.we = WordEncoder(bert_emb, self.config).to(self.config.device) self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device) self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device) + self.zeros_predictor = nn.Sequential( + nn.Linear(bert_emb, bert_emb), + 
nn.ReLU(), + nn.Linear(bert_emb, 1) + ).to(self.config.device) self.trainable: Dict[str, torch.nn.Module] = { "bert": self.bert, "we": self.we, "rough_scorer": self.rough_scorer, "pw": self.pw, "a_scorer": self.a_scorer, - "sp": self.sp + "sp": self.sp, "zeros_predictor": self.zeros_predictor } def _build_optimizers(self): @@ -785,4 +834,3 @@ def _set_training(self, value: bool): self._training = value for module in self.trainable.values(): module.train(self._training) - diff --git a/stanza/models/coref/utils.py b/stanza/models/coref/utils.py index 027308a31c..af6ad1963d 100644 --- a/stanza/models/coref/utils.py +++ b/stanza/models/coref/utils.py @@ -3,6 +3,7 @@ from typing import List, Set import torch +import torch.nn.functional as F from stanza.models.coref.const import EPSILON @@ -33,3 +34,57 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False): else: dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore return torch.cat((dummy, tensor), dim=1) + +def sigmoid_focal_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2, + reduction: str = "none", +) -> torch.Tensor: + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + + Args: + inputs (Tensor): A float tensor of arbitrary shape. + The predictions for each example. + targets (Tensor): A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha (float): Weighting factor in range [0, 1] to balance + positive vs negative examples or -1 for ignore. Default: ``0.25``. + gamma (float): Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. Default: ``2``. + reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` + ``'none'``: No reduction will be applied to the output. + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'``. + Returns: + Loss tensor with the reduction option applied. + """ + # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py + + if not (0 <= alpha <= 1) and alpha != -1: + raise ValueError(f"Invalid alpha value: {alpha}. 
alpha must be in the range [0,1] or -1 for ignore.") + + p = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + # Check reduction option and return loss accordingly + if reduction == "none": + pass + elif reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + else: + raise ValueError( + f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" + ) + return loss diff --git a/stanza/utils/datasets/coref/convert_udcoref.py b/stanza/utils/datasets/coref/convert_udcoref.py index 8ef3186628..2d46e87b92 100644 --- a/stanza/utils/datasets/coref/convert_udcoref.py +++ b/stanza/utils/datasets/coref/convert_udcoref.py @@ -10,6 +10,7 @@ from stanza.utils.conll import CoNLL +import warnings from random import Random import argparse @@ -22,6 +23,7 @@ UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1 def process_documents(docs, augment=False): + # docs = sections processed_section = [] for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)): @@ -67,8 +69,10 @@ def process_documents(docs, augment=False): span_clusters = defaultdict(list) word_clusters = defaultdict(list) head2span = [] + is_zero = [] word_total = 0 SPANS = re.compile(r"(\(\w+|[%\w]+\))") + do_ctn = False # if we broke in the loop for parsed_sentence in doc.sentences: # spans regex # parse the misc column, leaving on "Entity" entries @@ -114,8 +118,29 @@ def process_documents(docs, augment=False): coref_spans.append([int(k), i[0], i[1]]) sentence_upos = [x.upos for x in parsed_sentence.all_words] sentence_heads = [x.head - 1 if x.head and x.head > 0 else None for x in parsed_sentence.all_words] + sentence_text = [x.text for x in parsed_sentence.all_words] + + # if "_" in sentence_text and sentence_text.index("_") in [j for i in coref_spans for j in i]: + # import ipdb + # ipdb.set_trace() for span in coref_spans: + zero = False + if sentence_text[span[1]] == "_" and span[1] == span[2]: + is_zero.append([span[0], True]) + zero = True + # oo! thaht's a zero coref, we should merge it forwards + # i.e. we pick the next word as the head! + span = [span[0], span[1]+1, span[2]+1] + # crap! there's two zeros right next to each other + # we are sad and confused so we give up in this case + if len(sentence_text) > span[1] and sentence_text[span[1]] == "_": + warnings.warn("Found two zeros next to each other in sequence; we are confused and therefore giving up.") + do_ctn = True + break + else: + is_zero.append([span[0], False]) + # input is expected to be start word, end word + 1 # counting from 0 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive @@ -124,34 +149,73 @@ def process_documents(docs, augment=False): # if its a zero coref (i.e. coref, but the head in None), we call # the beginning of the span (i.e. 
the zero itself) the head - if not sentence_heads[span_start-1]: + if zero: candidate_head = span[1] else: - candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) - if candidate_head is None: - for candidate_head in range(span[1], span[2] + 1): - # stanza uses 0 to mark the head, whereas OntoNotes is counting - # words from 0, so we have to subtract 1 from the stanza heads - #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) - # treat the head of the phrase as the first word that has a head outside the phrase - if parsed_sentence.all_words[candidate_head].head and ( - parsed_sentence.all_words[candidate_head].head - 1 < span[1] or - parsed_sentence.all_words[candidate_head].head - 1 > span[2] - ): - break - else: - # if none have a head outside the phrase (circular??) - # then just take the first word - candidate_head = span[1] + try: + candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) + except RecursionError: + candidate_head = span[1] + + if candidate_head is None: + for candidate_head in range(span[1], span[2] + 1): + # stanza uses 0 to mark the head, whereas OntoNotes is counting + # words from 0, so we have to subtract 1 from the stanza heads + #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) + # treat the head of the phrase as the first word that has a head outside the phrase + if parsed_sentence.all_words[candidate_head].head and ( + parsed_sentence.all_words[candidate_head].head - 1 < span[1] or + parsed_sentence.all_words[candidate_head].head - 1 > span[2] + ): + break + else: + # if none have a head outside the phrase (circular??) + # then just take the first word + candidate_head = span[1] #print("----> %d" % candidate_head) candidate_head += word_total span_clusters[span[0]].append((span_start, span_end)) word_clusters[span[0]].append(candidate_head) head2span.append((candidate_head, span_start, span_end)) + if do_ctn: + break word_total += len(parsed_sentence.all_words) + if do_ctn: + continue span_clusters = sorted([sorted(values) for _, values in span_clusters.items()]) word_clusters = sorted([sorted(values) for _, values in word_clusters.items()]) head2span = sorted(head2span) + is_zero = [i for _,i in sorted(is_zero)] + + # remove zero tokens "_" from cased_words and adjust indices accordingly + zero_positions = [i for i, w in enumerate(cased_words) if w == "_"] + if zero_positions: + old_to_new = {} + new_idx = 0 + for old_idx, w in enumerate(cased_words): + if w != "_": + old_to_new[old_idx] = new_idx + new_idx += 1 + cased_words = [w for w in cased_words if w != "_"] + sent_id = [sent_id[i] for i in sorted(old_to_new.keys())] + deprel = [deprel[i] for i in sorted(old_to_new.keys())] + heads = [heads[i] for i in sorted(old_to_new.keys())] + try: + span_clusters = [ + [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster] + for cluster in span_clusters + ] + except: + warnings.warn("Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. 
We are giving up.") + continue + word_clusters = [ + [old_to_new[h] for h in cluster] + for cluster in word_clusters + ] + head2span = [ + (old_to_new[h], old_to_new[s], old_to_new[e - 1] + 1) + for h, s, e in head2span + ] processed = { "document_id": doc_id, @@ -164,7 +228,8 @@ def process_documents(docs, augment=False): "span_clusters": span_clusters, "word_clusters": word_clusters, "head2span": head2span, - "lang": lang + "lang": lang, + "is_zero": is_zero } processed_section.append(processed) return processed_section @@ -182,6 +247,7 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_ lang = load.split("/")[-1].split("_")[0] print("Ingesting %s from %s of lang %s" % (section, load, lang)) docs = CoNLL.conll2multi_docs(load, ignore_gapping=False) + # sections = docs[:10] print(" Ingested %d documents" % len(docs)) if split_test and section == 'train': test_section = [] @@ -302,4 +368,3 @@ def main(): if __name__ == '__main__': main() - From 9053968a1727a20edf4b41d6e26d816614613a8b Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 3 Aug 2025 09:12:02 -0700 Subject: [PATCH 3/9] inference processor for coref --- stanza/pipeline/coref_processor.py | 57 +++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/stanza/pipeline/coref_processor.py b/stanza/pipeline/coref_processor.py index 3c6bb2e01f..989e96fb51 100644 --- a/stanza/pipeline/coref_processor.py +++ b/stanza/pipeline/coref_processor.py @@ -4,6 +4,7 @@ from stanza.models.common.utils import misc_to_space_after from stanza.models.coref.coref_chain import CorefMention, CorefChain +from stanza.models.common.doc import Word from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor @@ -99,8 +100,13 @@ def process(self, document): } coref_input = self._model.build_doc(coref_input) results = self._model.run(coref_input) + + + # Handle zero anaphora - zero_scores is always predicted + zero_nodes_created = self._handle_zero_anaphora(document, results, sent_ids, word_pos) + clusters = [] - for span_cluster in results.span_clusters: + for cluster_idx, span_cluster in enumerate(results.span_clusters): if len(span_cluster) == 0: continue span_cluster = sorted(span_cluster) @@ -144,6 +150,14 @@ def process(self, document): start_word = word_pos[span[0]] end_word = word_pos[span[1]-1] + 1 mentions.append(CorefMention(sent_id, start_word, end_word)) + + # Add zero node mentions to this cluster if any exist + for zero_cluster_idx, zero_sent_id, zero_word_decimal_id in zero_nodes_created: + if zero_cluster_idx == cluster_idx: + # Zero node is a single "word" mention at the decimal position + import math + end_word = math.floor(zero_word_decimal_id) + 1 + mentions.append(CorefMention(zero_sent_id, zero_word_decimal_id, end_word)) representative = mentions[best_span] representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word) @@ -152,3 +166,44 @@ def process(self, document): document.coref = clusters return document + + def _handle_zero_anaphora(self, document, results, sent_ids, word_pos): + """Handle zero anaphora by creating zero nodes and updating coreference clusters.""" + if results.zero_scores is None or results.word_clusters is None: + return + + zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores + + # Flatten word_clusters to get the word indices that correspond to zero_scores + cluster_word_ids = [] + for 
cluster in results.word_clusters: + cluster_word_ids.extend(cluster) + + # Find indices where zero_scores > 0 + zero_indices = (zero_scores > 0).nonzero(as_tuple=True)[0] + + for zero_idx in zero_indices: + zero_idx = zero_idx.item() + if zero_idx >= len(cluster_word_ids): + continue + + word_idx = cluster_word_ids[zero_idx] + sent_id = sent_ids[word_idx] + word_id = word_pos[word_idx] + + # Create zero node - attach BEFORE the current word + # This means the zero node comes after word_id-1 but before word_id + if word_id > 0: + zero_word_id = (word_id, 1) # attach after word_id-1, before word_id + zero_word = Word(document.sentences[sent_id], { + "text": "_", + "lemma": "_", + "id": zero_word_id + }) + document.sentences[sent_id]._empty_words.append(zero_word) + + # Track this zero node for adding to coreference clusters + cluster_idx, _ = cluster_mapping[zero_idx] + zero_nodes_created.append((cluster_idx, sent_id, word_id + 0.1)) + + return zero_nodes_created From 6c3833d7a90091a1de87ab61ddbf41c7cc3595c0 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 3 Aug 2025 11:00:01 -0700 Subject: [PATCH 4/9] fixes for zero coref inference --- stanza/models/common/doc.py | 18 +++--- stanza/pipeline/coref_processor.py | 98 ++++++++++++++++++++---------- 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py index 0ffb62bd9c..e73fad8d4a 100644 --- a/stanza/models/common/doc.py +++ b/stanza/models/common/doc.py @@ -486,18 +486,22 @@ def coref(self, chains): def _attach_coref_mentions(self, chains): for sentence in self.sentences: - for word in sentence.words: + for word in sentence.all_words: word.coref_chains = [] for chain in chains: for mention_idx, mention in enumerate(chain.mentions): sentence = self.sentences[mention.sentence] - for word_idx in range(mention.start_word, mention.end_word): - is_start = word_idx == mention.start_word - is_end = word_idx == mention.end_word - 1 - is_representative = mention_idx == chain.representative_index - attachment = CorefAttachment(chain, is_start, is_end, is_representative) - sentence.words[word_idx].coref_chains.append(attachment) + if isinstance(mention.start_word, tuple): + attachment = CorefAttachment(chain, True, True, False) + sentence._empty_words[mention.start_word[1]-1].coref_chains.append(attachment) + else: + for word_idx in range(mention.start_word, mention.end_word): + is_start = word_idx == mention.start_word + is_end = word_idx == mention.end_word - 1 + is_representative = mention_idx == chain.representative_index + attachment = CorefAttachment(chain, is_start, is_end, is_representative) + sentence.words[word_idx].coref_chains.append(attachment) def reindex_sentences(self, start_index): for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences): diff --git a/stanza/pipeline/coref_processor.py b/stanza/pipeline/coref_processor.py index 989e96fb51..b85e60f74b 100644 --- a/stanza/pipeline/coref_processor.py +++ b/stanza/pipeline/coref_processor.py @@ -9,6 +9,8 @@ from stanza.pipeline._constants import * from stanza.pipeline.processor import UDProcessor, register_processor +import torch + def extract_text(document, sent_id, start_word, end_word): sentence = document.sentences[sent_id] tokens = [] @@ -128,6 +130,11 @@ def process(self, document): best_span = None max_propn = 0 for span_idx, span in enumerate(span_cluster): + word_idx = results.word_clusters[cluster_idx][span_idx] + is_zero = zero_nodes_created.get((cluster_idx, word_idx)) 
+ if is_zero: + continue + sent_id = sent_ids[span[0]] sentence = sentences[sent_id] start_word = word_pos[span[0]] @@ -145,21 +152,33 @@ def process(self, document): max_propn = num_propn mentions = [] - for span in span_cluster: - sent_id = sent_ids[span[0]] - start_word = word_pos[span[0]] - end_word = word_pos[span[1]-1] + 1 - mentions.append(CorefMention(sent_id, start_word, end_word)) - - # Add zero node mentions to this cluster if any exist - for zero_cluster_idx, zero_sent_id, zero_word_decimal_id in zero_nodes_created: - if zero_cluster_idx == cluster_idx: - # Zero node is a single "word" mention at the decimal position - import math - end_word = math.floor(zero_word_decimal_id) + 1 - mentions.append(CorefMention(zero_sent_id, zero_word_decimal_id, end_word)) - representative = mentions[best_span] - representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word) + for span_idx, span in enumerate(span_cluster): + word_idx = results.word_clusters[cluster_idx][span_idx] + is_zero = zero_nodes_created.get((cluster_idx, word_idx)) + if is_zero: + (sent_id, zero_word_id) = is_zero + # if the word id is a tuple, it will be attached + # to the zero + mentions.append( + CorefMention( + sent_id, + zero_word_id, + zero_word_id + ) + ) + else: + sent_id = sent_ids[span[0]] + start_word = word_pos[span[0]] + end_word = word_pos[span[1]-1] + 1 + mentions.append(CorefMention(sent_id, start_word, end_word)) + + # if we ended up with no best span, then our "representative text" + # is just underscore + if best_span: + representative = mentions[best_span] + representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word) + else: + representative_text = "_" chain = CorefChain(len(clusters), mentions, representative_text, best_span) clusters.append(chain) @@ -173,15 +192,26 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos): return zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores + is_zero = [] # Flatten word_clusters to get the word indices that correspond to zero_scores cluster_word_ids = [] - for cluster in results.word_clusters: + cluster_mapping = {} + counter = 0 + for indx, cluster in enumerate(results.word_clusters): + for _ in range(len(cluster)): + cluster_mapping[counter] = indx + counter += 1 cluster_word_ids.extend(cluster) # Find indices where zero_scores > 0 - zero_indices = (zero_scores > 0).nonzero(as_tuple=True)[0] - + print(zero_scores) + zero_indices = (zero_scores > 0.0).nonzero() + + # this dict maps (cluster_id, word_id) to (cluster_id, start, end) + # which overrides span_clusters + zero_to_coref = {} + for zero_idx in zero_indices: zero_idx = zero_idx.item() if zero_idx >= len(cluster_word_ids): @@ -193,17 +223,21 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos): # Create zero node - attach BEFORE the current word # This means the zero node comes after word_id-1 but before word_id - if word_id > 0: - zero_word_id = (word_id, 1) # attach after word_id-1, before word_id - zero_word = Word(document.sentences[sent_id], { - "text": "_", - "lemma": "_", - "id": zero_word_id - }) - document.sentences[sent_id]._empty_words.append(zero_word) - - # Track this zero node for adding to coreference clusters - cluster_idx, _ = cluster_mapping[zero_idx] - zero_nodes_created.append((cluster_idx, sent_id, word_id + 0.1)) - - return zero_nodes_created + zero_word_id = ( + word_id, 
+ len(document.sentences[sent_id]._empty_words)+1 + ) # attach after word_id-1, before word_id + zero_word = Word(document.sentences[sent_id], { + "text": "_", + "lemma": "_", + "id": zero_word_id + }) + document.sentences[sent_id]._empty_words.append(zero_word) + + # Track this zero node for adding to coreference clusters + cluster_idx = cluster_mapping[zero_idx] + zero_to_coref[(cluster_idx, word_idx)] = ( + sent_id, zero_word_id + ) + + return zero_to_coref From 7b7e69af79c86453b4df69ce16c3cd20343ec163 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Tue, 5 Aug 2025 21:15:23 -0700 Subject: [PATCH 5/9] small debugging patches to support empty node prediction --- stanza/models/common/doc.py | 6 +++--- stanza/models/coref/dataset.py | 14 -------------- stanza/models/coref/model.py | 9 +++------ stanza/pipeline/coref_processor.py | 3 +-- stanza/utils/datasets/coref/convert_udcoref.py | 6 +++--- 5 files changed, 10 insertions(+), 28 deletions(-) diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py index e73fad8d4a..df5a72a8d9 100644 --- a/stanza/models/common/doc.py +++ b/stanza/models/common/doc.py @@ -750,10 +750,10 @@ def all_words(self): words = self._words empty_words = self._empty_words - all = sorted(words + empty_words, key=lambda x:(x.id,) - if isinstance(x.id, int) else x.id) + all_words = sorted(words + empty_words, + key=lambda x:(x.id,) if isinstance(x.id, int) else x.id) - return all + return all_words @property def ents(self): diff --git a/stanza/models/coref/dataset.py b/stanza/models/coref/dataset.py index 7fd81ceaef..3efe90379d 100644 --- a/stanza/models/coref/dataset.py +++ b/stanza/models/coref/dataset.py @@ -38,9 +38,6 @@ def __init__(self, path, config, tokenizer): word2subword = [] subwords = [] word_id = [] - nonblank_subwords = [] # a list of subwords, skipping _ - previous_was_blank = [] # was the word before _? 
- was_blank = False # a flag to set if we saw "_" for i, word in enumerate(doc["cased_words"]): tokenized = self.tokenizer.tokenize(word) if len(tokenized) == 0: @@ -53,17 +50,6 @@ def __init__(self, path, config, tokenizer): word2subword.append((len(subwords), len(subwords) + len(tokenized_word))) subwords.extend(tokenized_word) word_id.extend([i] * len(tokenized_word)) - if word == "_": - was_blank = True - else: - nonblank_subwords.extend(tokenized_word) - previous_was_blank.extend( - [True if was_blank else False]+[False]*(len(tokenized_word)-1) - ) - was_blank = False - - doc["nonblank_subwords"] = nonblank_subwords - doc["blank_prefix"] = previous_was_blank doc["word2subword"] = word2subword doc["subwords"] = subwords diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index 384135eb45..e59a5979b2 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -512,13 +512,11 @@ def train(self, log=False): else: s_loss = torch.zeros_like(c_loss) - del res - (c_loss + s_loss + z_loss).backward() running_c_loss += c_loss.item() running_s_loss += s_loss.item() - if z_loss: + if res.zero_scores.size(0) != 0: running_z_loss += z_loss.item() # log every 100 docs @@ -527,12 +525,11 @@ def train(self, log=False): 'train_c_loss': c_loss.item(), 'train_s_loss': s_loss.item(), } - if z_loss: + if res.zero_scores.size(0) != 0: logged['train_z_loss'] = z_loss.item() wandb.log(logged) - - del c_loss, s_loss, z_loss + del c_loss, s_loss, z_loss, res for optim in self.optimizers.values(): optim.step() diff --git a/stanza/pipeline/coref_processor.py b/stanza/pipeline/coref_processor.py index b85e60f74b..9f43328e3d 100644 --- a/stanza/pipeline/coref_processor.py +++ b/stanza/pipeline/coref_processor.py @@ -174,7 +174,7 @@ def process(self, document): # if we ended up with no best span, then our "representative text" # is just underscore - if best_span: + if best_span is not None: representative = mentions[best_span] representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word) else: @@ -205,7 +205,6 @@ def _handle_zero_anaphora(self, document, results, sent_ids, word_pos): cluster_word_ids.extend(cluster) # Find indices where zero_scores > 0 - print(zero_scores) zero_indices = (zero_scores > 0.0).nonzero() # this dict maps (cluster_id, word_id) to (cluster_id, start, end) diff --git a/stanza/utils/datasets/coref/convert_udcoref.py b/stanza/utils/datasets/coref/convert_udcoref.py index 2d46e87b92..d70c0946f1 100644 --- a/stanza/utils/datasets/coref/convert_udcoref.py +++ b/stanza/utils/datasets/coref/convert_udcoref.py @@ -129,7 +129,7 @@ def process_documents(docs, augment=False): if sentence_text[span[1]] == "_" and span[1] == span[2]: is_zero.append([span[0], True]) zero = True - # oo! thaht's a zero coref, we should merge it forwards + # oo! that's a zero coref, we should merge it forwards # i.e. we pick the next word as the head! span = [span[0], span[1]+1, span[2]+1] # crap! 
there's two zeros right next to each other @@ -163,7 +163,7 @@ def process_documents(docs, augment=False): # words from 0, so we have to subtract 1 from the stanza heads #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1) # treat the head of the phrase as the first word that has a head outside the phrase - if parsed_sentence.all_words[candidate_head].head and ( + if (parsed_sentence.all_words[candidate_head].head is not None) and ( parsed_sentence.all_words[candidate_head].head - 1 < span[1] or parsed_sentence.all_words[candidate_head].head - 1 > span[2] ): @@ -205,7 +205,7 @@ def process_documents(docs, augment=False): [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster] for cluster in span_clusters ] - except: + except (KeyError, TypeError) as _: # two errors, either end-1 = -1, or start/end is None warnings.warn("Somehow, we are still coreffering to a zero. This is likely due to multiple zeros on top of each other. We are giving up.") continue word_clusters = [ From 784d8e646be1c13c8ea745efe7319ea18ba44fbb Mon Sep 17 00:00:00 2001 From: John Bauer Date: Fri, 26 Sep 2025 17:52:37 -0700 Subject: [PATCH 6/9] If a saved dataset doesn't have zeros built in to the dataset, treat each text with missing zeros as having no zeros (eg, is_zero == 0 for every position) --- stanza/models/coref/model.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index e59a5979b2..321d1cc958 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -157,8 +157,12 @@ def evaluate(self, res = self.run(doc, True) # measure zero prediction accuracy - zero_targets = torch.tensor(doc["is_zero"], device=res.zero_scores.device) - zero_preds = (res.zero_scores > 0).view(-1).to(zero_targets.dtype) + zero_preds = (res.zero_scores > 0).view(-1).to(device=res.zero_scores.device) + is_zero = doc.get("is_zero") + if is_zero is None: + zero_targets = torch.zeros_like(zero_preds, device=zero_preds.device) + else: + zero_targets = torch.tensor(is_zero, device=zero_preds.device) z_correct += (zero_preds == zero_targets).sum().item() z_total += zero_targets.numel() @@ -499,10 +503,12 @@ def train(self, log=False): if res.zero_scores.size(0) == 0: z_loss = 0.0 # since there are no corefs else: - z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), - (torch.tensor(doc["is_zero"]) - .to(res.zero_scores.device).float()), - reduction="mean") + is_zero = doc.get("is_zero") + if is_zero is None: + is_zero = torch.zeros_like(res.zero_scores.squeeze(-1), device=res.zero_scores.device, dtype=torch.float) + else: + is_zero = torch.tensor(is_zero).to(res.zero_scores.device).float() + z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1), is_zero, reduction="mean") c_loss = self._coref_criterion(res.coref_scores, res.coref_y) From b0de92ee4e1e494a965fa3eb78e478fbb1c46ff1 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sat, 27 Sep 2025 00:16:39 -0700 Subject: [PATCH 7/9] Add an option to combine one or more languages when using convert_udcoref --- stanza/utils/datasets/coref/convert_udcoref.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/stanza/utils/datasets/coref/convert_udcoref.py b/stanza/utils/datasets/coref/convert_udcoref.py index d70c0946f1..4aeb42cbb1 100644 --- a/stanza/utils/datasets/coref/convert_udcoref.py +++ b/stanza/utils/datasets/coref/convert_udcoref.py @@ -311,12 +311,17 @@ def main(): group = 
parser.add_mutually_exclusive_group(required=True) group.add_argument('--directory', type=str, help="the name of the subfolder for data conversion") group.add_argument('--project', type=str, help="Look for and use a set of datasets for data conversion - Slavic or Hungarian") + group.add_argument('--languages', type=str, help="Only use these specific languages from the coref directory") args = parser.parse_args() coref_input_path = paths['COREF_BASE'] coref_output_path = paths['COREF_DATA_DIR'] - if args.project: + if args.languages: + langs = args.languages.split(",") + project = "_".join(langs) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project: if args.project == 'baltoslavic': project = "baltoslavic_udcoref" langs = ('Polish', 'Russian', 'Czech', 'Old_Church_Slavonic', 'Lithuanian') From 3a810da22bb76bd1e17c4521e7fbce5ed2eef81e Mon Sep 17 00:00:00 2001 From: John Bauer Date: Tue, 30 Sep 2025 21:12:06 -0700 Subject: [PATCH 8/9] Allow the pipeline to pass along a --coref_log_norms parameter to the coref model --- stanza/pipeline/coref_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/pipeline/coref_processor.py b/stanza/pipeline/coref_processor.py index 9f43328e3d..dab0ab4939 100644 --- a/stanza/pipeline/coref_processor.py +++ b/stanza/pipeline/coref_processor.py @@ -70,7 +70,7 @@ def _set_up_model(self, config, pipeline, device): # (except its config) # TODO: separate any pretrains if possible # TODO: add device parameter to the load mechanism - config_update = {'log_norms': False, + config_update = {'log_norms': config.get('log_norms', False), 'device': device} model = CorefModel.load_model(path=config['model_path'], ignore={"bert_optimizer", "general_optimizer", From a6d7cf7b769028417fbf4e017bc9ab764df20c99 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 1 Oct 2025 00:08:37 -0700 Subject: [PATCH 9/9] Load models which were trained without zeros to always predict non-zero. 
Also, don't both training the zero predictor if a dataset has no zeros in it --- stanza/models/coref/config.py | 2 ++ stanza/models/coref/coref_config.toml | 3 +++ stanza/models/coref/model.py | 20 +++++++++++++++++--- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/stanza/models/coref/config.py b/stanza/models/coref/config.py index 328431edd5..821fde01a7 100644 --- a/stanza/models/coref/config.py +++ b/stanza/models/coref/config.py @@ -65,3 +65,5 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me singletons: bool max_train_len: int + use_zeros: bool + diff --git a/stanza/models/coref/coref_config.toml b/stanza/models/coref/coref_config.toml index 968a89d20f..bdca781745 100755 --- a/stanza/models/coref/coref_config.toml +++ b/stanza/models/coref/coref_config.toml @@ -119,6 +119,9 @@ conll_log_dir = "data/conll_logs" # Skip any documents longer than this length max_train_len = 5000 +# if this is set to false, the model will set its zero_predictor to, well, 0 +use_zeros = true + # ============================================================================= # Extra keyword arguments to be passed to bert tokenizers of specified models [DEFAULT.tokenizer_kwargs] diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index 321d1cc958..9a033d03e4 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -478,6 +478,14 @@ def train(self, log=False): docs_ids = list(range(len(docs))) avg_spans = docs.avg_span + # for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros + training_has_zeros = any('is_zero' in doc for doc in docs) + if not training_has_zeros: + logger.info("No zeros found in the dataset. The zeros predictor will set to 0") + if self.epochs_trained == 0: + # new model, set it to always predict not-zero + self.disable_zeros_predictor() + best_f1 = None for epoch in range(self.epochs_trained, self.config.train_epochs): self.training = True @@ -500,7 +508,7 @@ def train(self, log=False): res = self.run(doc) - if res.zero_scores.size(0) == 0: + if res.zero_scores.size(0) == 0 or not training_has_zeros: z_loss = 0.0 # since there are no corefs else: is_zero = doc.get("is_zero") @@ -522,7 +530,7 @@ def train(self, log=False): running_c_loss += c_loss.item() running_s_loss += s_loss.item() - if res.zero_scores.size(0) != 0: + if res.zero_scores.size(0) != 0 and training_has_zeros: running_z_loss += z_loss.item() # log every 100 docs @@ -531,7 +539,7 @@ def train(self, log=False): 'train_c_loss': c_loss.item(), 'train_s_loss': s_loss.item(), } - if res.zero_scores.size(0) != 0: + if res.zero_scores.size(0) != 0 and training_has_zeros: logged['train_z_loss'] = z_loss.item() wandb.log(logged) @@ -666,6 +674,8 @@ def _build_model(self, foundation_cache): nn.ReLU(), nn.Linear(bert_emb, 1) ).to(self.config.device) + if not hasattr(self.config, 'use_zeros') or not self.config.use_zeros: + self.disable_zeros_predictor() self.trainable: Dict[str, torch.nn.Module] = { "bert": self.bert, "we": self.we, @@ -674,6 +684,10 @@ def _build_model(self, foundation_cache): "sp": self.sp, "zeros_predictor": self.zeros_predictor } + def disable_zeros_predictor(self): + nn.init.zeros_(self.zeros_predictor[-1].weight) + nn.init.zeros_(self.zeros_predictor[-1].bias) + def _build_optimizers(self): n_docs = len(self._get_docs(self.config.train_data)) self.optimizers: Dict[str, torch.optim.Optimizer] = {}
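The zero-anaphora machinery above is spread across several patches, so a condensed sketch of how the pieces fit together may help. The two-layer zeros_predictor and the sigmoid focal loss below mirror the code added to stanza/models/coref/model.py and stanza/models/coref/utils.py; the hidden size (bert_emb = 768), the toy mention embeddings, and the is_zero labels are invented for illustration and are not part of the patches.

import torch
import torch.nn as nn
import torch.nn.functional as F

def sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    # Same formulation as the loss added in stanza/models/coref/utils.py
    # (RetinaNet focal loss, https://arxiv.org/abs/1708.02002)
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * (1 - p_t) ** gamma
    if alpha >= 0:
        loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

bert_emb = 768  # illustrative hidden size; the real value comes from the transformer config

# Two-layer MLP over mention-head embeddings, as in _build_model()
zeros_predictor = nn.Sequential(
    nn.Linear(bert_emb, bert_emb),
    nn.ReLU(),
    nn.Linear(bert_emb, 1),
)

# Toy batch: embeddings for five coreferring head words, one of which is a zero ("_")
mention_embs = torch.randn(5, bert_emb)
is_zero = torch.tensor([0., 0., 1., 0., 0.])

zero_scores = zeros_predictor(mention_embs).squeeze(-1)  # [n_mentions]
z_loss = sigmoid_focal_loss(zero_scores, is_zero)        # added to c_loss + s_loss during training
zero_preds = zero_scores > 0                             # logit > 0  <=>  probability > 0.5

print(float(z_loss), zero_preds.tolist())

At data preparation time, convert_udcoref.py records which CorefUD mentions were empty "_" nodes (is_zero) and removes those tokens from cased_words, remapping all span and head indices. At inference time, coref_processor.py runs the same head over the predicted mention heads and, wherever the logit is positive, inserts a CoNLL-U style empty word into the sentence and adds it to the coreference cluster; the focal loss keeps the rare positive (zero) class from being drowned out during training.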