| 1 | +"""This is a fused cross-entropy and linear layer. Idea is copied |
| 2 | +from https://github.com/linkedin/Liger-Kernel who just copied it from |
| 3 | +https://github.com/mgmalek/efficient_cross_entropy |
| 4 | +""" |

import torch
from torch.autograd import Function
from torch.nn import functional as F


class CrossEntropyLoopedFused(Function):
    @staticmethod
    def forward(ctx, weight: torch.Tensor, act: torch.Tensor, labels: torch.Tensor):
        # weight: [V, H] lm-head weight, act: [B, H] activations, labels: [B]
        bs = act.shape[0]
        weight_grad = torch.zeros_like(weight)
        act_grad = torch.empty_like(act)
        out_loss = torch.tensor(0.0, device=act.device)
        chunksize = 2048  # batch rows processed per iteration

        for b in range(0, bs, chunksize):
            end_idx = min(b + chunksize, bs)

            # Get current batch chunks
            act_chunk = act[b:end_idx]  # [chunk_size, H]
            labels_chunk = labels[b:end_idx]  # [chunk_size]

            # Compute logits
            logits = F.linear(act_chunk, weight)  # [chunk_size, V]

            # Compute softmax, subtracting the max for numerical stability
            max_logits = torch.max(logits, dim=-1, keepdim=True)[0]
            exp_logits = torch.exp(logits - max_logits)
            sum_exp = torch.sum(exp_logits, dim=-1, keepdim=True)
            probs = exp_logits / sum_exp  # [chunk_size, V]

            # Per-sample loss: logsumexp(logits) - logits[label]
            correct_logits = torch.gather(
                logits, 1, labels_chunk.unsqueeze(1)
            )  # [chunk_size, 1]
            out_loss += torch.sum(
                max_logits.squeeze()
                + torch.log(sum_exp.squeeze())
                - correct_logits.squeeze()
            )

            # Gradient wrt logits: dL/dlogits = softmax(logits) - one_hot(labels)
            dprobs = probs.clone()  # [chunk_size, V]
            dprobs.scatter_(
                1,
                labels_chunk.unsqueeze(1),
                dprobs.gather(1, labels_chunk.unsqueeze(1)) - 1,
            )

            # Accumulate gradients
            weight_grad += dprobs.T @ act_chunk  # [V, H]
            act_grad[b:end_idx] = dprobs @ weight  # [chunk_size, H]

        # Scale by 1/batch to match mean reduction
        scale = 1.0 / bs
        weight_grad *= scale
        act_grad *= scale

        ctx.save_for_backward(weight_grad, act_grad)
        return scale * out_loss

    @staticmethod
    def backward(ctx, grad_output):  # type: ignore
        # Gradients were precomputed in forward; backward only scales them
        # by the upstream gradient of the (scalar) loss.
        weight_grad, act_grad = ctx.saved_tensors
        return grad_output * weight_grad, grad_output * act_grad, None


# torch.compile does a good enough job generating the kernel here, so no
# hand-written kernel is needed.
@torch.compile
def fused_cross_entropy(lm_head_weight, act, labels):
    return CrossEntropyLoopedFused.apply(lm_head_weight, act, labels)
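

# A minimal sanity check, a sketch with made-up sizes rather than part of the
# kernel: it compares the fused loss and gradients against the straightforward
# unfused version (F.cross_entropy over F.linear).
if __name__ == "__main__":
    torch.manual_seed(0)
    bs, hidden, vocab = 4096, 256, 1000
    act = torch.randn(bs, hidden, requires_grad=True)
    weight = torch.randn(vocab, hidden, requires_grad=True)
    labels = torch.randint(0, vocab, (bs,))

    loss = fused_cross_entropy(weight, act, labels)
    loss.backward()

    act_ref = act.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    ref_loss = F.cross_entropy(F.linear(act_ref, weight_ref), labels)
    ref_loss.backward()

    # Loose tolerances: chunked accumulation changes the fp32 summation order.
    torch.testing.assert_close(loss, ref_loss, rtol=1e-4, atol=1e-5)
    torch.testing.assert_close(act.grad, act_ref.grad, rtol=1e-4, atol=1e-5)
    torch.testing.assert_close(weight.grad, weight_ref.grad, rtol=1e-4, atol=1e-5)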