Fix blank prior usage
mmueller00 committed Dec 4, 2024
1 parent 4692f63 commit 27a00f4
Showing 2 changed files with 145 additions and 108 deletions.
14 changes: 11 additions & 3 deletions users/mueller/experiments/ctc_baseline/ctc.py
@@ -66,6 +66,7 @@ def py():
train_small = True
with_prior = True
use_sum_criterion = True
aux_loss = False

if train_small:
epochs = 50
@@ -106,6 +107,7 @@ def py():
alias_name = f"ctc-baseline" + \
(f"-full_sum" if use_sum_criterion else "") + \
(f"-self_training_{self_training_rounds}" if self_training_rounds > 0 else "") + \
(f"-wo_aux_loss" if not aux_loss else "") + \
(f"-dataset_size_{test_self_training_on_small_dataset}" if test_self_training_on_small_dataset > 0 else "") + \
(f"-ds100h" if train_small else "") + \
f"-{vocab}" + \
@@ -142,6 +144,7 @@ def py():
test_self_training_on_small_dataset = test_self_training_on_small_dataset,
with_prior = with_prior,
use_sum_criterion=use_sum_criterion,
aux_loss=aux_loss
)


@@ -176,6 +179,7 @@ def train_exp(
test_self_training_on_small_dataset: int = 0,
with_prior: bool = False,
use_sum_criterion: bool = False,
aux_loss: bool = False,
) -> Optional[ModelWithCheckpoints]:
"""
Train experiment
@@ -274,7 +278,9 @@ def train_exp(
config_self["lm_path"] = lm_path

init_checkpoint = model_with_checkpoint[i].get_last_fixed_epoch().checkpoint

# config_self.pop("__num_processes")
if not aux_loss:
config_self.pop("aux_loss_layers")
model_with_checkpoint.append(train(
prefix_self_training,
task=task,
@@ -287,7 +293,7 @@
num_epochs=num_epochs,
gpu_mem=gpu_mem,
num_processes=num_processes,
time_rqmt=time_rqmt if time_rqmt else 132,
time_rqmt=time_rqmt if time_rqmt else (312 if use_sum_criterion else 156),
))
train_job = model_with_checkpoint[i + 1].get_training_job()
if env_updates:
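A side note on the config handling added above: dict.pop("aux_loss_layers") without a default raises a KeyError if the key is missing from config_self. A minimal, hypothetical sketch of a more defensive variant is shown here (same key and flag names as in the diff; whether the key is always present in the self-training config is an assumption):

# Sketch only: drop the auxiliary CTC losses for self-training when the flag is off;
# pop with a default so a config without "aux_loss_layers" is not an error.
if not aux_loss:
    config_self.pop("aux_loss_layers", None)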
@@ -832,7 +838,7 @@ def ctc_sum_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, lm

# torch.autograd.set_detect_anomaly(True)

def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> Tensor:
def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
lengths = lengths.to(log_probs.device)
assert lengths.size(0) == log_probs.size(0), "Prior calculation batch lengths are not the same (full_sum)!"

@@ -851,6 +857,8 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> Tensor:

with torch.no_grad():
assert log_mean_probs.exp().sum().allclose(torch.tensor(1.0, device=log_mean_probs.device)), f"Prior probs do not sum to 1.0, but to {log_mean_probs.exp().sum()}"
if log_mean_probs.isclose(torch.tensor([0.0], device=log_probs.device)).any() or log_mean_probs.isinf().any() or log_mean_probs.isnan().any():
print("Prior probs contain inf or nan or 0 values!", log_mean_probs, log_mean_probs.exp())

return log_mean_probs
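For context on what the added check guards: _calc_log_prior is expected to return a log-domain prior whose exponentiated values sum to 1 over the vocabulary, and the new lines flag priors that contain 0, inf, or NaN entries. The collapsed lines of the function are not shown in this diff, so the following is only a hypothetical sketch of one way such a frame-averaged log prior can be computed, assuming log_probs has shape [batch, time, vocab] and lengths holds the number of valid frames per sequence:

import torch

def _calc_log_prior_sketch(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # log_probs: [B, T, V] frame-wise log posteriors; lengths: [B] valid frame counts.
    lengths = lengths.to(log_probs.device)
    # Mask padded frames so they do not contribute to the average.
    mask = torch.arange(log_probs.size(1), device=log_probs.device)[None, :] < lengths[:, None]
    masked = log_probs.masked_fill(~mask[:, :, None], float("-inf"))
    # Mean over all valid frames, computed in log space:
    # logsumexp over batch and time, minus log(total number of valid frames).
    log_sum = torch.logsumexp(masked.reshape(-1, log_probs.size(2)), dim=0)
    return log_sum - torch.log(lengths.sum().to(log_probs.dtype))

If every frame's posterior sums to 1, the result again sums to 1 after exp, which is exactly what the assert above verifies.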

