Commit

cr_loss_on_aux_probs
albertz committed Jan 7, 2025
1 parent 0119285 commit 53ce0b1
Showing 1 changed file with 46 additions and 16 deletions.
62 changes: 46 additions & 16 deletions users/zeyer/experiments/exp2024_10_16_consistency_reg_ctc.py
@@ -95,6 +95,7 @@ def cr_ctc_training(*, model: Model, data: Tensor, data_spatial_dim: Dim, target
aed_loss_scale = config.float("aed_loss_scale", 1.0)
use_normalized_loss = config.bool("use_normalized_loss", True)
cr_loss_scale = config.float("cr_loss_scale", 0.2)
+cr_loss_on_aux_probs = config.bool("cr_loss_on_aux_probs", False)
aed_loss_bug_fix = config.bool("aed_loss_bug_fix", False)
use_fixed_ctc_grad = config.typed_value("use_fixed_ctc_grad", False)
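
As a usage note (not part of the diff): a minimal sketch of how the new switch might be set in a RETURNN training config dict. Only the keys read above via config.float()/config.bool() are taken from the commit; the fragment name and values are illustrative assumptions.

# hypothetical config fragment
config_update = dict(
    cr_loss_scale=0.2,           # weight of the consistency-regularization loss
    cr_loss_on_aux_probs=True,   # new option: also apply the CR loss to aux-layer log-probs
    use_normalized_loss=True,
)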

@@ -149,6 +150,16 @@ def cr_ctc_training(*, model: Model, data: Tensor, data_spatial_dim: Dim, target
# )
# error.mark_as_loss("label", as_error=True, custom_inv_norm_factor=targets_spatial_dim.get_size_tensor())

+if cr_loss_on_aux_probs:
+    _cr_loss(
+        f"consistency_{layer_idx}",
+        aux_log_probs,
+        branch_dim=branch_dim,
+        wb_target_dim=model.wb_target_dim,
+        scale=cr_loss_scale * aux_loss_scales[i],
+        use_normalized_loss=use_normalized_loss,
+    )
+
log_probs = model.log_probs_wb_from_logits(logits)
loss = ctc_loss(
logits=log_probs,
@@ -165,23 +176,14 @@ def cr_ctc_training(*, model: Model, data: Tensor, data_spatial_dim: Dim, target
use_normalized_loss=use_normalized_loss,
)

-assert branch_dim in log_probs.dims
-log_probs_a = rf.gather(log_probs, axis=branch_dim, indices=0)
-log_probs_b = rf.gather(log_probs, axis=branch_dim, indices=1)
-consistency_reg_a = rf.cross_entropy(
-    estimated=log_probs_a,
-    estimated_type="log-probs",
-    target=rf.stop_gradient(rf.exp(log_probs_b)),
-    axis=model.wb_target_dim,
-)
-consistency_reg_b = rf.cross_entropy(
-    estimated=log_probs_b,
-    estimated_type="log-probs",
-    target=rf.stop_gradient(rf.exp(log_probs_a)),
-    axis=model.wb_target_dim,
+_cr_loss(
+    "consistency",
+    log_probs,
+    branch_dim=branch_dim,
+    wb_target_dim=model.wb_target_dim,
+    scale=cr_loss_scale,
+    use_normalized_loss=use_normalized_loss,
)
-consistency_reg = (consistency_reg_a + consistency_reg_b) * 0.5
-consistency_reg.mark_as_loss("consistency", scale=cr_loss_scale, use_normalized_loss=use_normalized_loss)

if model.decoder:
# potentially also other types but just assume
@@ -229,3 +231,31 @@ def cr_ctc_training(*, model: Model, data: Tensor, data_spatial_dim: Dim, target

cr_ctc_training: TrainDef[Model]
cr_ctc_training.learning_rate_control_error_measure = "ctc"


+def _cr_loss(
+    loss_name: str,
+    log_probs: Tensor,
+    *,
+    branch_dim: Dim,
+    wb_target_dim: Dim,
+    scale: float,
+    use_normalized_loss: bool,
+):
+    assert branch_dim in log_probs.dims
+    log_probs_a = rf.gather(log_probs, axis=branch_dim, indices=0)
+    log_probs_b = rf.gather(log_probs, axis=branch_dim, indices=1)
+    consistency_reg_a = rf.cross_entropy(
+        estimated=log_probs_a,
+        estimated_type="log-probs",
+        target=rf.stop_gradient(rf.exp(log_probs_b)),
+        axis=wb_target_dim,
+    )
+    consistency_reg_b = rf.cross_entropy(
+        estimated=log_probs_b,
+        estimated_type="log-probs",
+        target=rf.stop_gradient(rf.exp(log_probs_a)),
+        axis=wb_target_dim,
+    )
+    consistency_reg = (consistency_reg_a + consistency_reg_b) * 0.5
+    consistency_reg.mark_as_loss(loss_name, scale=scale, use_normalized_loss=use_normalized_loss)
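
For intuition, here is a small self-contained PyTorch sketch (an illustration, not the RETURNN-frontend code from this commit) of what _cr_loss computes per frame: each branch's log-probs are scored with cross-entropy against the gradient-stopped probabilities of the other branch, and the two directions are averaged. Tensor names and shapes are assumptions for the example.

import torch

def cr_loss_sketch(log_probs_a: torch.Tensor, log_probs_b: torch.Tensor) -> torch.Tensor:
    """log_probs_*: [batch, time, vocab_incl_blank], already log-softmaxed."""
    # CE(target=stop_grad(p_b), estimated=log p_a), summed over the vocab axis
    ce_a = -(log_probs_b.exp().detach() * log_probs_a).sum(dim=-1)
    # CE(target=stop_grad(p_a), estimated=log p_b)
    ce_b = -(log_probs_a.exp().detach() * log_probs_b).sum(dim=-1)
    return 0.5 * (ce_a + ce_b)  # per-frame consistency loss, shape [batch, time]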
