diff --git a/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
index 6d1f781..dd5509c 100644
--- a/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
+++ b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
@@ -15,7 +15,7 @@
 # NKI_EXAMPLE_31_BEGIN
 @nki.jit
 def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
-                                           mixed_percision=True):
+                                           mixed_precision=True):
   """
   Fused self attention kernel for small head dimension Stable Diffusion workload,
   simplified for this tutorial.
@@ -38,14 +38,14 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=
   IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+   - If mixed_precision is True, then all Tensor Engine operations will be performed in
      bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
      will be in the same type as the inputs.
   """
   # Use q_ref dtype as the intermediate tensor dtype
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
-  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
+  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
 
   assert q_ref.dtype == k_ref.dtype == v_ref.dtype
 
   # Shape checking
diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_jax.py b/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_jax.py
new file mode 100644
index 0000000..8e48d88
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_jax.py
@@ -0,0 +1,34 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+JAX implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.
+
+"""
+# NKI_EXAMPLE_50_BEGIN
+import jax
+import jax.numpy as jnp
+# NKI_EXAMPLE_50_END
+
+from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2
+
+# NKI_EXAMPLE_50_BEGIN
+if __name__ == "__main__":
+
+  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
+  a = jax.random.uniform(seed_a, (512, 2048), dtype=jnp.bfloat16)
+  b = jax.random.uniform(seed_b, (512, 2048), dtype=jnp.bfloat16)
+
+  output_nki = nki_tensor_add_nc2(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_jax = a + b
+  print(f"output_jax={output_jax}")
+
+  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and JAX match")
+  else:
+    print("NKI and JAX differ")
+
+  assert allclose
+  # NKI_EXAMPLE_50_END
diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_nki_kernels.py
new file mode 100644
index 0000000..1b52ca9
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_nki_kernels.py
@@ -0,0 +1,61 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+NKI implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.
+
+"""
+import numpy as np
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+from spmd_tensor_addition_nki_kernels import nki_tensor_add_kernel_
+
+
+# NKI_EXAMPLE_48_BEGIN
+def nki_tensor_add_nc2(a_input, b_input):
+  """NKI kernel caller to compute element-wise addition of two input tensors using multiple Neuron cores.
+
+  This kernel caller lifts the tile-size restriction by applying the kernel to tiles of the inputs/outputs.
+  a_input and b_input are sharded across Neuron cores, directly utilizing Trn2 architecture capabilities.
+
+  Args:
+      a_input: the first input tensor, of shape [N*128, M*512]
+      b_input: the second input tensor, of shape [N*128, M*512]
+
+  Returns:
+      a tensor of shape [N*128, M*512], the result of a_input + b_input
+  """
+
+  # The SPMD launch grid denotes the number of kernel instances.
+  # In this case, we use a 2D grid where the size of each invocation is 128x512
+  # Since we're sharding the 1st dimension across Neuron cores, we slice it in steps of
+  # 128 rows per core * 2 cores = 256 rows
+  grid_x = a_input.shape[0] // (128 * 2)
+  grid_y = a_input.shape[1] // 512
+
+  # In addition, we distribute the kernel to physical Neuron cores along the first dimension
+  # of the SPMD grid.
+  # This means:
+  # Physical NC [0]: kernel[n, m] where n is even
+  # Physical NC [1]: kernel[n, m] where n is odd
+  # Note that by specifying this information in the SPMD grid, we can use multiple Neuron cores
+  # without updating the original `nki_tensor_add_kernel_` kernel.
+  return nki_tensor_add_kernel_[nl.spmd_dim(grid_x, nl.nc(2)), grid_y](a_input, b_input)
+  # NKI_EXAMPLE_48_END
+
+if __name__ == "__main__":
+  a = np.random.rand(512, 2048).astype(np.float16)
+  b = np.random.rand(512, 2048).astype(np.float16)
+
+  output_nki = nki_tensor_add_nc2(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_np = a + b
+  print(f"output_np={output_np}")
+
+  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and NumPy match")
+  else:
+    print("NKI and NumPy differ")
+
+  assert allclose
diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_torch.py b/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_torch.py
new file mode 100644
index 0000000..9ec3695
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_torch.py
@@ -0,0 +1,35 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+PyTorch implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.
+
+"""
+# NKI_EXAMPLE_49_BEGIN
+import torch
+from torch_xla.core import xla_model as xm
+# NKI_EXAMPLE_49_END
+
+from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2
+
+
+# NKI_EXAMPLE_49_BEGIN
+if __name__ == "__main__":
+  device = xm.xla_device()
+
+  a = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)
+  b = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)
+
+  output_nki = nki_tensor_add_nc2(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_torch = a + b
+  print(f"output_torch={output_torch}")
+
+  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and Torch match")
+  else:
+    print("NKI and Torch differ")
+
+  assert allclose
+  # NKI_EXAMPLE_49_END
diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_jax.py b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_jax.py
new file mode 100644
index 0000000..c1f566f
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_jax.py
@@ -0,0 +1,34 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+JAX implementation for SPMD tensor addition NKI tutorial.
+
+"""
+# NKI_EXAMPLE_30_BEGIN
+import jax
+import jax.numpy as jnp
+# NKI_EXAMPLE_30_END
+
+from spmd_tensor_addition_nki_kernels import nki_tensor_add
+
+# NKI_EXAMPLE_30_BEGIN
+if __name__ == "__main__":
+
+  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
+  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
+  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)
+
+  output_nki = nki_tensor_add(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_jax = a + b
+  print(f"output_jax={output_jax}")
+
+  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and JAX match")
+  else:
+    print("NKI and JAX differ")
+
+  assert allclose
+  # NKI_EXAMPLE_30_END
diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py
new file mode 100644
index 0000000..508d5c4
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py
@@ -0,0 +1,91 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+NKI implementation for SPMD tensor addition NKI tutorial.
+
+"""
+import numpy as np
+# NKI_EXAMPLE_27_BEGIN
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+
+
+@nki.jit
+def nki_tensor_add_kernel_(a_input, b_input):
+  """NKI kernel to compute element-wise addition of two input tensors
+
+  This kernel assumes the input/output sizes can be uniformly tiled into [128, 512] tiles
+
+  Args:
+      a_input: the first input tensor
+      b_input: the second input tensor
+
+  Returns:
+      c_output: an output tensor
+  """
+  # Create output tensor shared between all SPMD instances as the result tensor
+  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)
+
+  # Calculate tile offsets based on current 'program'
+  offset_i_x = nl.program_id(0) * 128
+  offset_i_y = nl.program_id(1) * 512
+
+  # Generate tensor indices to index tensors a and b
+  ix = offset_i_x + nl.arange(128)[:, None]
+  iy = offset_i_y + nl.arange(512)[None, :]
+
+  # Load input data from device memory (HBM) to on-chip memory (SBUF)
+  # We refer to an indexed portion of a tensor as an intermediate tensor
+  a_tile = nl.load(a_input[ix, iy])
+  b_tile = nl.load(b_input[ix, iy])
+
+  # compute a + b
+  c_tile = a_tile + b_tile
+
+  # store the addition results back to device memory (c_output)
+  nl.store(c_output[ix, iy], value=c_tile)
+
+  # Transfer the ownership of `c_output` to the caller
+  return c_output
+  # NKI_EXAMPLE_27_END
+
+
+# NKI_EXAMPLE_28_BEGIN
+def nki_tensor_add(a_input, b_input):
+  """NKI kernel caller to compute element-wise addition of two input tensors
+
+  This kernel caller lifts the tile-size restriction by applying the kernel to tiles of the inputs/outputs
+
+  Args:
+      a_input: the first input tensor, of shape [N*128, M*512]
+      b_input: the second input tensor, of shape [N*128, M*512]
+
+  Returns:
+      a tensor of shape [N*128, M*512], the result of a_input + b_input
+  """
+
+  # The SPMD launch grid denotes the number of kernel instances.
+  # In this case, we use a 2D grid where the size of each invocation is 128x512
+  grid_x = a_input.shape[0] // 128
+  grid_y = a_input.shape[1] // 512
+
+  return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
+  # NKI_EXAMPLE_28_END
+
+if __name__ == "__main__":
+  a = np.random.rand(256, 1024).astype(np.float16)
+  b = np.random.rand(256, 1024).astype(np.float16)
+
+  output_nki = nki_tensor_add(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_np = a + b
+  print(f"output_np={output_np}")
+
+  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and NumPy match")
+  else:
+    print("NKI and NumPy differ")
+
+  assert allclose
diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_torch.py b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_torch.py
new file mode 100644
index 0000000..8df0f3b
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_torch.py
@@ -0,0 +1,35 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+PyTorch implementation for SPMD tensor addition NKI tutorial.
+
+"""
+# NKI_EXAMPLE_29_BEGIN
+import torch
+from torch_xla.core import xla_model as xm
+# NKI_EXAMPLE_29_END
+
+from spmd_tensor_addition_nki_kernels import nki_tensor_add
+
+
+# NKI_EXAMPLE_29_BEGIN
+if __name__ == "__main__":
+  device = xm.xla_device()
+
+  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
+  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
+
+  output_nki = nki_tensor_add(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_torch = a + b
+  print(f"output_torch={output_torch}")
+
+  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and Torch match")
+  else:
+    print("NKI and Torch differ")
+
+  assert allclose
+  # NKI_EXAMPLE_29_END
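
For reference, the launch-grid arithmetic described in the `nki_tensor_add_nc2` comments can be checked on the host. The sketch below is illustrative only (plain Python, outside the patch, with no Neuron APIs) and assumes the even/odd logical-to-physical core mapping stated in those comments.

# Illustrative sketch: enumerate the SPMD launch grid used by nki_tensor_add_nc2
# for the 512x2048 inputs in its __main__ block. Assumes the mapping stated in the
# patch comments: even grid rows on physical NC 0, odd grid rows on NC 1.
shape = (512, 2048)   # input shape from the example
tile = (128, 512)     # tile handled by each nki_tensor_add_kernel_ instance
num_cores = 2         # nl.nc(2): shard the first grid dimension over two Neuron cores

grid_x = shape[0] // (tile[0] * num_cores)   # 512 // 256 = 2
grid_y = shape[1] // tile[1]                 # 2048 // 512 = 4

# spmd_dim(grid_x, nc(2)) covers grid_x * num_cores instances along the first dimension,
# so the full launch here is 4 x 4 = 16 instances, each adding one 128x512 tile.
for n in range(grid_x * num_cores):
    for m in range(grid_y):
        core = n % num_cores                 # even n -> NC 0, odd n -> NC 1
        rows = (n * tile[0], (n + 1) * tile[0])
        cols = (m * tile[1], (m + 1) * tile[1])
        print(f"kernel[{n},{m}] -> NC{core}, rows {rows[0]}:{rows[1]}, cols {cols[0]}:{cols[1]}")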