diff --git a/raptgen/models.py b/raptgen/models.py index 4121148..d174d68 100644 --- a/raptgen/models.py +++ b/raptgen/models.py @@ -33,6 +33,7 @@ def train(epochs, model, train_loader, test_loader, optimizer, loss_fn=None, dev csv_filename = model_str.replace(".mdl", ".csv") if loss_fn == profile_hmm_loss_fn and force_matching: logger.info(f"force till {force_epochs}") + max_beta = beta patient = 0 losses = [] test_losses = [] @@ -44,7 +45,9 @@ def train(epochs, model, train_loader, test_loader, optimizer, loss_fn=None, dev description = "" for epoch in range(1, epochs + 1): if beta_schedule and epoch < threshold: - beta = epoch / threshold + beta = epoch / threshold * max_beta + else: + beta = max_beta model.train() train_loss = 0 for data in train_loader: