
Commit 73b6e76

commit fused_cross_entropy kernel for hawk too
1 parent fba0d37

File tree: 2 files changed, +87 -7 lines changed


hawk/fused_cross_entropy.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
"""This is a fused cross-entropy and linear layer. Idea is copied
from https://github.com/linkedin/Liger-Kernel who just copied it from
https://github.com/mgmalek/efficient_cross_entropy
"""

import torch
from torch.autograd import Function
from torch.nn import functional as F


class CrossEntropyLoopedFused(Function):
    @staticmethod
    def forward(ctx, weight: torch.Tensor, act: torch.Tensor, labels: torch.Tensor):
        bs = act.shape[0]
        weight_grad = torch.zeros_like(weight)
        act_grad = torch.empty_like(act)
        out_loss = torch.tensor(0.0, device=act.device)
        chunksize = 2048

        for b in range(0, bs, chunksize):
            end_idx = min(b + chunksize, bs)

            # Get current batch chunks
            act_chunk = act[b:end_idx]  # [chunk_size, H]
            labels_chunk = labels[b:end_idx]  # [chunk_size]

            # Compute logits
            logits = F.linear(act_chunk, weight)  # [chunk_size, V]

            # Compute softmax and loss
            max_logits = torch.max(logits, dim=-1, keepdim=True)[0]
            exp_logits = torch.exp(logits - max_logits)
            sum_exp = torch.sum(exp_logits, dim=-1, keepdim=True)
            probs = exp_logits / sum_exp  # [chunk_size, V]

            # Compute loss using gather
            correct_logits = torch.gather(
                logits, 1, labels_chunk.unsqueeze(1)
            )  # [chunk_size, 1]
            out_loss += torch.sum(
                max_logits.squeeze()
                + torch.log(sum_exp.squeeze())
                - correct_logits.squeeze()
            )

            # Compute gradients
            dprobs = probs.clone()  # [chunk_size, V]
            dprobs.scatter_(
                1,
                labels_chunk.unsqueeze(1),
                dprobs.gather(1, labels_chunk.unsqueeze(1)) - 1,
            )

            # Accumulate gradients
            weight_grad += dprobs.T @ act_chunk  # [V, H]
            act_grad[b:end_idx] = dprobs @ weight  # [chunk_size, H]

        # Scale gradients
        scale = 1.0 / bs
        weight_grad *= scale
        act_grad *= scale

        ctx.save_for_backward(weight_grad, act_grad)
        return scale * out_loss

    @staticmethod
    def backward(ctx, grad_output):  # type: ignore
        (
            weight_grad,
            act_grad,
        ) = ctx.saved_tensors
        return grad_output * weight_grad, grad_output * act_grad, None


# torch.compile does a good enough job with the kernel here
@torch.compile
def fused_cross_entropy(lm_head_weight, act, labels):
    return CrossEntropyLoopedFused.apply(lm_head_weight, act, labels)
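
A quick way to convince yourself the fused op matches the stock PyTorch path is to compare loss and gradients against F.linear followed by F.cross_entropy on random inputs. This is a hypothetical sanity check, not part of the commit; the sizes and the import path hawk.fused_cross_entropy below are assumptions. Because CrossEntropyLoopedFused computes gradients for both the weight and the activations during its forward pass, a single backward() call should reproduce what autograd gives on the reference path.

# Hypothetical sanity check, not part of the commit. Sizes and import path are assumed.
import torch
from torch.nn import functional as F

from hawk.fused_cross_entropy import fused_cross_entropy

torch.manual_seed(0)
V, H, N = 1024, 64, 4096  # assumed vocab size, hidden size, number of tokens

weight_a = torch.randn(V, H, requires_grad=True)
act_a = torch.randn(N, H, requires_grad=True)
labels = torch.randint(0, V, (N,))

# Independent leaf copies for the reference path so gradients don't mix.
weight_b = weight_a.detach().clone().requires_grad_(True)
act_b = act_a.detach().clone().requires_grad_(True)

loss_fused = fused_cross_entropy(weight_a, act_a, labels)
loss_ref = F.cross_entropy(F.linear(act_b, weight_b), labels)  # materialises [N, V] logits

loss_fused.backward()
loss_ref.backward()

assert torch.allclose(loss_fused, loss_ref, atol=1e-4)
assert torch.allclose(weight_a.grad, weight_b.grad, atol=1e-4)
assert torch.allclose(act_a.grad, act_b.grad, atol=1e-4)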

hawk/hawk.py

Lines changed: 8 additions & 7 deletions
@@ -10,6 +10,7 @@
 
 from .cache import RNNCache
 from .external import BlockDiagonalLinear, Conv1D, rnn_param_init
+from .fused_cross_entropy import fused_cross_entropy
 from .scan_fused import fused_linear_scan
 
 # ------
@@ -289,21 +290,21 @@ def forward(
 
         hidden_states = self.norm(hidden_states)
 
-        logits = self.lm_head(hidden_states)
-
         if self.use_cache:
+            logits = self.lm_head(hidden_states)
             return logits, rnn_rnn_cache_list
 
         else:
             if labels is not None:
-                shift_logits = logits[..., :-1, :].contiguous()
 
                 shift_labels = labels[..., 1:].contiguous()
-
-                loss_fct = torch.nn.CrossEntropyLoss()
-                loss = loss_fct(
-                    shift_logits.view(-1, shift_logits.size(-1)),
+                shift_x = hidden_states[..., :-1, :].contiguous()
+                loss = fused_cross_entropy(
+                    self.lm_head.weight,
+                    shift_x.view(-1, shift_x.size(-1)),
                     shift_labels.view(-1),
                 )
                 return loss
+
+            logits = self.lm_head(hidden_states)
             return logits
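
The reason the training branch now hands the shifted hidden states and self.lm_head.weight to fused_cross_entropy, instead of computing logits first, is memory: the looped kernel only ever materialises one chunk of logits (plus activation- and weight-shaped gradient buffers), rather than the full [tokens, vocab] logits tensor the old CrossEntropyLoss path needed, and it never has to save logits for backward. A rough back-of-the-envelope comparison, with made-up shapes and an assumed 2 bytes per logit (bf16), neither of which is taken from the repo:

# Illustrative only: shapes and dtype below are assumptions.
B, T, H, V = 8, 2048, 1024, 256_000  # hypothetical batch, sequence, hidden, vocab sizes
bytes_per = 2                        # assuming bf16 logits
chunksize = 2048                     # chunk size hard-coded in CrossEntropyLoopedFused

full_logits_gib = B * T * V * bytes_per / 2**30       # old path: all logits at once
chunk_logits_gib = chunksize * V * bytes_per / 2**30  # fused path: one chunk at a time
print(f"full logits ~{full_logits_gib:.1f} GiB, per-chunk logits ~{chunk_logits_gib:.2f} GiB")
# roughly 7.8 GiB vs 0.98 GiB under these assumed shapes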
