Skip to content

Commit

Permalink
comment on torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Nov 11, 2024
1 parent f53483a commit 4feda35
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions mamba/fused_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def backward(ctx, grad_output):
return grad_output * weight_grad, grad_output * act_grad, None


# torch.compile does a good enough job with the kernel here
@torch.compile
def fused_cross_entropy(lm_head_weight, act, labels):
return CrossEntropyLoopedFused.apply(lm_head_weight, act, labels)

0 comments on commit 4feda35

Please sign in to comment.