import torch
+import triton
+import triton.language as tl

+# LayerNorm adapted from the Triton tutorial, used for Cohere q, k norm
+# X: [N, head_num, head_dim]
+# W: [head_num, head_dim]
+@triton.jit
+def _layer_norm_fwd_kernel(
+    X,  # pointer to the input
+    W,  # pointer to the weights
+    Y,  # pointer to the output
+    stride_x_N,
+    stride_x_hn,
+    stride_x_hd,
+    stride_y_N,
+    stride_y_hn,
+    stride_y_hd,
+    stride_w_hn,
+    stride_w_hd,
+    N,  # number of columns in X (head_dim)
+    eps,  # epsilon to avoid division by zero
+    BLOCK_SIZE: tl.constexpr,
+):
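+    # each program handles one (token, head) pair and normalizes its head_dim row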
+    Seq = tl.program_id(0)
+    H = tl.program_id(1)

-def layernorm_forward(x, weight, eps):
-    return torch.layer_norm(x, (x.shape[-1],), weight, bias=None, eps=eps)
+    X += Seq * stride_x_N + H * stride_x_hn
+    Y += Seq * stride_y_N + H * stride_y_hn
+    W += H * stride_w_hn

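+    # the column offsets below address each row directly, so the last dim
+    # (head_dim) is assumed contiguous (stride_x_hd == stride_y_hd == stride_w_hd == 1)
+    # first pass: block-wise partial sums reduced to the row mean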
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+        _mean += a
+    mean = tl.sum(_mean, axis=0) / N

-def multi_head_layernorm_forward(x, weight, eps):
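+    # second pass: accumulate (x - mean)^2 the same way to get the variance and rstd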
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+        x = tl.where(cols < N, x - mean, 0.0)
+        _var += x * x
+    var = tl.sum(_var, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+
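+    # third pass: normalize, scale by the per-head weight, and store in X's dtype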
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        w = tl.load(W + cols, mask=mask).to(tl.float32)
+        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+        x_hat = (x - mean) * rstd
+        y = x_hat * w
+
+        tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask)
+
+
+def layernorm_forward(
+    X,  # input tensor [N, head_num, head_dim]
+    W,  # weight tensor [head_num, head_dim]
+    eps,  # epsilon to avoid division by zero
+):
+    assert len(X.shape) == 3
+    assert len(W.shape) == 2
+    assert X.shape[-1] == W.shape[-1]
+    assert X.shape[-2] == W.shape[-2]
+
+    y = torch.empty_like(X)
+
+    stride_x_N = X.stride(0)
+    stride_x_hn = X.stride(1)
+    stride_x_hd = X.stride(2)
+
+    stride_y_N = y.stride(0)
+    stride_y_hn = y.stride(1)
+    stride_y_hd = y.stride(2)
+
+    stride_w_hn = W.stride(0)
+    stride_w_hd = W.stride(1)
+
+    N = X.shape[-1]
+    BLOCK_SIZE = 128
+
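+    # launch one program per (token, head) pair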
+    grid = (X.shape[0], X.shape[1])
+    _layer_norm_fwd_kernel[grid](
+        X,
+        W,
+        y,
+        stride_x_N,
+        stride_x_hn,
+        stride_x_hd,
+        stride_y_N,
+        stride_y_hn,
+        stride_y_hd,
+        stride_w_hn,
+        stride_w_hd,
+        N,
+        eps,
+        BLOCK_SIZE,
+    )
+
+    return y
+
+
+def torch_layernorm(x, weight, eps):
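+    # fp32 reference LayerNorm, used as the ground truth in test_layernorm below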
    inp_dtype = x.dtype
    x = x.to(torch.float32)
    mean = x.mean(-1, keepdim=True)
    variance = (x - mean).pow(2).mean(-1, keepdim=True)
    x = (x - mean) * torch.rsqrt(variance + eps)
    x = weight.to(torch.float32) * x
    return x.to(inp_dtype)
+
+
+def test_layernorm(eps=1e-5):
+    # create data
+    dtype = torch.float16
+    x_shape = (5, 1, 128)
+    w_shape = (x_shape[-2], x_shape[-1])
+    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    # forward pass
+    y_ref = torch_layernorm(x, weight, eps).to(dtype)
+    y_out = layernorm_forward(x, weight, eps)
+
+    # compare
+    print("type:", y_out.dtype, y_ref.dtype)
+    print("max delta:", torch.max(torch.abs(y_out - y_ref)))
+    assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0)
+    return