[BugFix] Propagate mask correctly in allocated_fused_rms_norm_qkv kernel
aws-qieqingy authored and aws-zhehongb committed Feb 5, 2025
1 parent c72a66b commit fc202b2
Showing 1 changed file with 10 additions and 10 deletions.
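The core of the fix is twofold: the mask predicate is simplified from the per-element form (2*i+i_interleave_grp)*pmax+ix < seqlen used on the initial load and final store to the tile-granular form (2*i + i_interleave_grp) * pmax < seqlen, and that predicate is now also propagated to the compute instructions in the loop body (the activation/reduce sequence, the normalization multiply, the transpose matmuls, the copy, the QKV matmul accumulation, and the store) rather than only the load and store. The sketch below is plain Python, not NKI, and uses assumed example sizes purely to show how the tile-granular predicate partitions the sequence.

pmax = 128     # partition tile size used by the kernel
seqlen = 300   # illustrative sequence length, not a multiple of 2 * pmax

for i in range((seqlen + 2 * pmax - 1) // (2 * pmax)):    # outer tile loop
    for i_interleave_grp in range(2):                     # the two interleaved halves
        tile_start = (2 * i + i_interleave_grp) * pmax
        mask = tile_start < seqlen   # tile-granular predicate introduced by this commit
        print(f"rows {tile_start}..{tile_start + pmax - 1}: {'compute' if mask else 'skip'}")

Because the predicate depends only on the tile index, an entire 128-row tile is either computed or skipped, so every instruction that consumes that tile can share the same mask.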
20 changes: 10 additions & 10 deletions src/nki_samples/reference/allocated_fused_linear.py
@@ -31,7 +31,7 @@ def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-
seqlen, dim = batchless_shape
_dim, head_dim = weights.shape

- assert dim <= 8192 and dim & 128 == 0, "Unsupported hidden dimension"
+ assert dim <= 8192 and dim % 128 == 0, "Unsupported hidden dimension"
assert _dim == dim, "Reduction dimension must match"
assert head_dim <= 512, "Head dimension must be 512 or less"
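The hunk above corrects the divisibility check: & is a bitwise AND, so dim & 128 == 0 only tests whether bit 7 of dim is clear, not whether dim is a multiple of 128. A small illustrative comparison of the two predicates (the dim values are examples chosen here, not taken from the repository):

for dim in (320, 384, 512):
    old_check = dim <= 8192 and dim & 128 == 0   # bitwise test of a single bit
    new_check = dim <= 8192 and dim % 128 == 0   # true divisibility test
    print(dim, old_check, new_check)
# 320 -> old True,  new False  (320 is not a multiple of 128)
# 384 -> old False, new True   (384 = 3 * 128 is valid but was rejected)
# 512 -> old True,  new True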

@@ -67,7 +67,7 @@ def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-
# Double buffer the input tensor
in_bufs = nl.ndarray((2, par_dim(pmax), dim), dtype=hidden.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260, num_free_tiles=(2,)))
for i_interleave_grp in nl.affine_range(2):
- in_bufs[i_interleave_grp] = nl.load(hidden[b, (2*i+i_interleave_grp)*pmax+ix, iy], mask=(2*i+i_interleave_grp)*pmax+ix < seqlen)
+ in_bufs[i_interleave_grp] = nl.load(hidden[b, (2*i+i_interleave_grp)*pmax+ix, iy], mask=((2*i + i_interleave_grp) * pmax < seqlen))
act = nl.ndarray((par_dim(pmax), dim), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2))

# Write the RMS and RMS Reciprocal tensors back out here, in-place
@@ -81,9 +81,9 @@ def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-
output_sbuf = nl.ndarray((par_dim(pmax), fmax), dtype=weights.dtype,
buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim)*2+(dim+1)*4))

- act[...] = nisa.activation_reduce(op=nl.square, data=in_bufs[i_interleave_grp], reduce_op=np.add, reduce_res=square_sum[...], bias=bias_placeholder[...])
- square_sum[...] = nisa.tensor_scalar(square_sum[...], np.multiply, scale, op1=np.add, operand1=eps)
- square_sum[...] = nisa.activation(op=nl.rsqrt, data=square_sum[...], bias=bias_placeholder[...])
+ act[...] = nisa.activation_reduce(op=nl.square, data=in_bufs[i_interleave_grp], reduce_op=np.add, reduce_res=square_sum[...], bias=bias_placeholder[...], mask=((2*i + i_interleave_grp) * pmax < seqlen))
+ square_sum[...] = nisa.tensor_scalar(square_sum[...], np.multiply, scale, op1=np.add, operand1=eps, mask=((2*i + i_interleave_grp) * pmax < seqlen))
+ square_sum[...] = nisa.activation(op=nl.rsqrt, data=square_sum[...], bias=bias_placeholder[...], mask=((2*i + i_interleave_grp) * pmax < seqlen))

# all PE array ops must output to FP32 on trn1 but must match input dtype in trn2
if nisa.get_nc_version() == nisa.nc_version.gen3:
@@ -95,20 +95,20 @@ def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-

for m in nl.affine_range(NUM_TRANSP_TILES):
# Perform (hidden .* RMS Reciprocal)^T in tiles of fmax (512)
- out_tile[i_rhs.p, m*fmax+i_rhs.x] = nl.multiply(in_bufs[i_interleave_grp, i_rhs.p, m*fmax + i_rhs.x], square_sum[...], dtype=weights.dtype)
+ out_tile[i_rhs.p, m*fmax+i_rhs.x] = nl.multiply(in_bufs[i_interleave_grp, i_rhs.p, m*fmax + i_rhs.x], square_sum[...], dtype=weights.dtype, mask=((2*i + i_interleave_grp) * pmax < seqlen))
for j in nl.affine_range(4):
transpose_res_psum[m, i_lhs.p, j*pmax+i_lhs.x] = nisa.nc_matmul(out_tile[i_lhs.p, (m*4+j) * pmax + i_lhs.x], identity_tensor[...],
- is_transpose=True)
- out_tile[i_rhs.p, m * 4*pmax + i_rhs.x] = nl.copy(transpose_res_psum[m], dtype=hidden.dtype)
+ is_transpose=True, mask=((2*i + i_interleave_grp) * pmax < seqlen))
+ out_tile[i_rhs.p, m * 4*pmax + i_rhs.x] = nl.copy(transpose_res_psum[m], dtype=hidden.dtype, mask=((2*i + i_interleave_grp) * pmax < seqlen))

# perform (RMSNorm(hidden)^T)^T @ wQKV
res_psum = nl.ndarray((1, par_dim(pmax), fmax), dtype=nl.float32,
buffer=ncc.psum.mod_alloc(base_bank=7, num_bank_tiles=(1,)))
for m in nl.affine_range(M):
- res_psum[0] += nisa.nc_matmul(out_tile[i_lhs.p, m*pmax+i_lhs.x], weights_buffer[m, i_rhs.p, i_rhs.x])
+ res_psum[0] += nisa.nc_matmul(out_tile[i_lhs.p, m*pmax+i_lhs.x], weights_buffer[m, i_rhs.p, i_rhs.x], mask=((2*i + i_interleave_grp) * pmax < seqlen))

output_sbuf[...] = nl.copy(res_psum[0], dtype=out_tensor.dtype)
nl.store(out_tensor[b, (2*i+i_interleave_grp)*pmax+i_res.p, i_res.x],
value=output_sbuf,
- mask=((2*i+i_interleave_grp)*pmax+i_res.p<seqlen) & (i_res.x<head_dim))
+ mask=(i_res.x<head_dim) & ((2*i + i_interleave_grp) * pmax < seqlen))
return out_tensor
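For orientation, the kernel fuses an RMS normalization over the hidden dimension with the QKV projection. A rough NumPy reference for that math, written here as a sketch under the assumption that the norm carries no learned scale (consistent with the square → mean → rsqrt sequence above and the kernel's two-tensor signature), might look like:

import numpy as np

def rms_norm_qkv_reference(hidden, weights, eps=1e-6):
    # hidden:  (batch, seqlen, dim); weights: (dim, head_dim)
    hidden_f32 = hidden.astype(np.float32)
    # mean of squares over the hidden dimension, then reciprocal square root
    rms_reciprocal = 1.0 / np.sqrt(np.mean(hidden_f32 ** 2, axis=-1, keepdims=True) + eps)
    normed = hidden_f32 * rms_reciprocal            # RMSNorm without a learned scale
    return normed @ weights.astype(np.float32)      # project to the fused Q/K/V heads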
