Skip to content

Commit 604bcf3

Browse files
committed
Load models which were trained without zeros to always predict non-zero. Also, don't bother training the zero predictor if a dataset has no zeros in it
1 parent 4bfedf1 commit 604bcf3

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

stanza/models/coref/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,5 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me
6565
singletons: bool
6666

6767
max_train_len: int
68+
use_zeros: bool
69+

stanza/models/coref/coref_config.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ conll_log_dir = "data/conll_logs"
119119
# Skip any documents longer than this length
120120
max_train_len = 5000
121121

122+
# if this is set to false, the model will set its zero_predictor to, well, 0
123+
use_zeros = true
124+
122125
# =============================================================================
123126
# Extra keyword arguments to be passed to bert tokenizers of specified models
124127
[DEFAULT.tokenizer_kwargs]

stanza/models/coref/model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,14 @@ def train(self, log=False):
478478
docs_ids = list(range(len(docs)))
479479
avg_spans = docs.avg_span
480480

481+
# for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros
482+
training_has_zeros = any('is_zero' in doc for doc in docs)
483+
if not training_has_zeros:
484+
logger.info("No zeros found in the dataset. The zeros predictor will set to 0")
485+
if self.epochs_trained == 0:
486+
# new model, set it to always predict not-zero
487+
self.disable_zeros_predictor()
488+
481489
best_f1 = None
482490
for epoch in range(self.epochs_trained, self.config.train_epochs):
483491
self.training = True
@@ -500,7 +508,7 @@ def train(self, log=False):
500508

501509
res = self.run(doc)
502510

503-
if res.zero_scores.size(0) == 0:
511+
if res.zero_scores.size(0) == 0 or not training_has_zeros:
504512
z_loss = 0.0 # since there are no corefs
505513
else:
506514
is_zero = doc.get("is_zero")
@@ -522,7 +530,7 @@ def train(self, log=False):
522530

523531
running_c_loss += c_loss.item()
524532
running_s_loss += s_loss.item()
525-
if res.zero_scores.size(0) != 0:
533+
if res.zero_scores.size(0) != 0 and training_has_zeros:
526534
running_z_loss += z_loss.item()
527535

528536
# log every 100 docs
@@ -531,7 +539,7 @@ def train(self, log=False):
531539
'train_c_loss': c_loss.item(),
532540
'train_s_loss': s_loss.item(),
533541
}
534-
if res.zero_scores.size(0) != 0:
542+
if res.zero_scores.size(0) != 0 and training_has_zeros:
535543
logged['train_z_loss'] = z_loss.item()
536544
wandb.log(logged)
537545

@@ -666,6 +674,8 @@ def _build_model(self, foundation_cache):
666674
nn.ReLU(),
667675
nn.Linear(bert_emb, 1)
668676
).to(self.config.device)
677+
if not hasattr(self.config, 'use_zeros') or not self.config.use_zeros:
678+
self.disable_zeros_predictor()
669679

670680
self.trainable: Dict[str, torch.nn.Module] = {
671681
"bert": self.bert, "we": self.we,
@@ -674,6 +684,10 @@ def _build_model(self, foundation_cache):
674684
"sp": self.sp, "zeros_predictor": self.zeros_predictor
675685
}
676686

687+
def disable_zeros_predictor(self):
688+
nn.init.zeros_(self.zeros_predictor[-1].weight)
689+
nn.init.zeros_(self.zeros_predictor[-1].bias)
690+
677691
def _build_optimizers(self):
678692
n_docs = len(self._get_docs(self.config.train_data))
679693
self.optimizers: Dict[str, torch.optim.Optimizer] = {}

0 commit comments

Comments
 (0)