diff --git a/configs/trainer/deepspeed.yaml b/configs/trainer/deepspeed.yaml index 3c3dc1c..91ca2b6 100644 --- a/configs/trainer/deepspeed.yaml +++ b/configs/trainer/deepspeed.yaml @@ -22,4 +22,4 @@ strategy: cpu_checkpointing: False # Gradient accumulation -accumulate_grad_batches: 2 \ No newline at end of file +accumulate_grad_batches: 1 \ No newline at end of file diff --git a/src/utils/loss.py b/src/utils/loss.py index 2a2f121..c1e37c3 100644 --- a/src/utils/loss.py +++ b/src/utils/loss.py @@ -136,7 +136,7 @@ def diffusion_loss( # Scale by (t**2 + σ**2) / (t + σ)**2 scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.mul(timesteps, sd_data) ** 2 + epsilon) - loss_diffusion = scaling_factor * mse + loss_diffusion = scaling_factor.squeeze(-1) * mse # (bs) # Smooth LDDT Loss # if use_smooth_lddt: