@@ -478,6 +478,14 @@ def train(self, log=False):
478478 docs_ids = list (range (len (docs )))
479479 avg_spans = docs .avg_span
480480
481+ # for a brand new model, we set the zeros prediction to all 0 if the dataset has no zeros
482+ training_has_zeros = any ('is_zero' in doc for doc in docs )
483+ if not training_has_zeros :
484+ logger .info ("No zeros found in the dataset. The zeros predictor will set to 0" )
485+ if self .epochs_trained == 0 :
486+ # new model, set it to always predict not-zero
487+ self .disable_zeros_predictor ()
488+
481489 best_f1 = None
482490 for epoch in range (self .epochs_trained , self .config .train_epochs ):
483491 self .training = True
@@ -500,7 +508,7 @@ def train(self, log=False):
500508
501509 res = self .run (doc )
502510
503- if res .zero_scores .size (0 ) == 0 :
511+ if res .zero_scores .size (0 ) == 0 or not training_has_zeros :
504512 z_loss = 0.0 # since there are no corefs
505513 else :
506514 is_zero = doc .get ("is_zero" )
@@ -522,7 +530,7 @@ def train(self, log=False):
522530
523531 running_c_loss += c_loss .item ()
524532 running_s_loss += s_loss .item ()
525- if res .zero_scores .size (0 ) != 0 :
533+ if res .zero_scores .size (0 ) != 0 and training_has_zeros :
526534 running_z_loss += z_loss .item ()
527535
528536 # log every 100 docs
@@ -531,7 +539,7 @@ def train(self, log=False):
531539 'train_c_loss' : c_loss .item (),
532540 'train_s_loss' : s_loss .item (),
533541 }
534- if res .zero_scores .size (0 ) != 0 :
542+ if res .zero_scores .size (0 ) != 0 and training_has_zeros :
535543 logged ['train_z_loss' ] = z_loss .item ()
536544 wandb .log (logged )
537545
@@ -666,6 +674,8 @@ def _build_model(self, foundation_cache):
666674 nn .ReLU (),
667675 nn .Linear (bert_emb , 1 )
668676 ).to (self .config .device )
677+ if not hasattr (self .config , 'use_zeros' ) or not self .config .use_zeros :
678+ self .disable_zeros_predictor ()
669679
670680 self .trainable : Dict [str , torch .nn .Module ] = {
671681 "bert" : self .bert , "we" : self .we ,
@@ -674,6 +684,10 @@ def _build_model(self, foundation_cache):
674684 "sp" : self .sp , "zeros_predictor" : self .zeros_predictor
675685 }
676686
687+ def disable_zeros_predictor (self ):
688+ nn .init .zeros_ (self .zeros_predictor [- 1 ].weight )
689+ nn .init .zeros_ (self .zeros_predictor [- 1 ].bias )
690+
677691 def _build_optimizers (self ):
678692 n_docs = len (self ._get_docs (self .config .train_data ))
679693 self .optimizers : Dict [str , torch .optim .Optimizer ] = {}
0 commit comments