Skip to content

Commit

Permalink
Use torch.log1p (#6930)
Browse files Browse the repository at this point in the history
This function provides greater precision than `log(1 + x)` for small
values of `x`.

Found with TorchFix https://github.com/pytorch-labs/torchfix/
  • Loading branch information
kit1980 authored Jan 8, 2025
1 parent c7f3032 commit c41b0c2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/sequence/fpdt_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _update_out_and_lse(
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
new_lse = lse + torch.log1p(torch.exp(block_lse - lse))

out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

Expand Down

0 comments on commit c41b0c2

Please sign in to comment.