Update tutorials code in NeuronSDK 2.21 release #41

Merged · 1 commit · Jan 7, 2025

@@ -15,7 +15,7 @@
# NKI_EXAMPLE_31_BEGIN
@nki.jit
def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
-                                           mixed_percision=True):
+                                           mixed_precision=True):
  """
  Fused self attention kernel for small head dimension Stable Diffusion workload,
  simplified for this tutorial.
@@ -38,14 +38,14 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=

  IO tensor dtypes:
   - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+   - If mixed_precision is True, then all Tensor Engine operation will be performed in
     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
     will be in the same type as the inputs.
  """
  # Use q_ref dtype as the intermediate tensor dtype
  # Assume all IO tensors have the same dtype
  kernel_dtype = q_ref.dtype
-  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
+  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
  assert q_ref.dtype == k_ref.dtype == v_ref.dtype

  # Shape checking
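The only functional change in this file is the spelling of the keyword argument (`mixed_percision` → `mixed_precision`), which also renames the flag any caller passes by keyword. A minimal, illustrative sketch of an updated call site (q, k, v stand in for whatever query/key/value tensors the caller already has; they are not part of the diff):

    # old keyword, no longer accepted by the kernel:
    # out = fused_self_attn_for_SD_small_head_size(q, k, v, use_causal_mask=False, mixed_percision=True)
    # corrected keyword:
    out = fused_self_attn_for_SD_small_head_size(q, k, v, use_causal_mask=False, mixed_precision=True)
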
@@ -0,0 +1,34 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

JAX implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.

"""
# NKI_EXAMPLE_50_BEGIN
import jax
import jax.numpy as jnp
# NKI_EXAMPLE_50_END

from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2

# NKI_EXAMPLE_50_BEGIN
if __name__ == "__main__":

  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (512, 2048), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (512, 2048), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
# NKI_EXAMPLE_50_END
@@ -0,0 +1,61 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

NKI implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.

"""
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from spmd_tensor_addition_nki_kernels import nki_tensor_add_kernel_


# NKI_EXAMPLE_48_BEGIN
def nki_tensor_add_nc2(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors using multiple Neuron cores.

  This kernel caller lifts the tile-size restriction by applying the kernel to tiles of the inputs/outputs.
  a_input and b_input are sharded across Neuron cores, directly utilizing Trn2 architecture capabilities.

  Args:
    a_input: the first input tensor, of shape [N*128, M*512]
    b_input: the second input tensor, of shape [N*128, M*512]

  Returns:
    a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where each invocation processes a 128x512 tile.
  # Since we shard across Neuron cores along the first dimension, each grid step
  # covers 128 rows per core * 2 cores = 256 rows.
  grid_x = a_input.shape[0] // (128 * 2)
  grid_y = a_input.shape[1] // 512

  # In addition, we distribute the kernel instances to physical Neuron cores along the
  # first dimension of the SPMD grid. This means:
  #   Physical NC [0]: kernel[n, m] where n is even
  #   Physical NC [1]: kernel[n, m] where n is odd
  # Note that by specifying this information in the SPMD grid, we can use multiple
  # Neuron cores without modifying the original `nki_tensor_add_kernel_` kernel.
  return nki_tensor_add_kernel_[nl.spmd_dim(grid_x, nl.nc(2)), grid_y](a_input, b_input)
# NKI_EXAMPLE_48_END
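
# Worked example (illustrative only, derived from the comments above): for the
# (512, 2048) inputs used in __main__ below,
#   grid_x = 512 // (128 * 2) = 2 and grid_y = 2048 // 512 = 4,
# and nl.spmd_dim(grid_x, nl.nc(2)) expands the first grid dimension to 2 * 2 = 4
# instances (one 128-row slice each), alternating between physical NC[0] (even n)
# and NC[1] (odd n), so 4 * 4 = 16 instances together cover the full 512x2048 tensor.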

if __name__ == "__main__":
  a = np.random.rand(512, 2048).astype(np.float16)
  b = np.random.rand(512, 2048).astype(np.float16)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_np = a + b
  print(f"output_np={output_np}")

  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and NumPy match")
  else:
    print("NKI and NumPy differ")

  assert allclose
@@ -0,0 +1,35 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

PyTorch implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.

"""
# NKI_EXAMPLE_49_BEGIN
import torch
from torch_xla.core import xla_model as xm
# NKI_EXAMPLE_49_END

from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2


# NKI_EXAMPLE_49_BEGIN
if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
# NKI_EXAMPLE_49_END
@@ -0,0 +1,34 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

JAX implementation for SPMD tensor addition NKI tutorial.

"""
# NKI_EXAMPLE_30_BEGIN
import jax
import jax.numpy as jnp
# NKI_EXAMPLE_30_END

from spmd_tensor_addition_nki_kernels import nki_tensor_add

# NKI_EXAMPLE_30_BEGIN
if __name__ == "__main__":

  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
# NKI_EXAMPLE_30_END
@@ -0,0 +1,91 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

NKI implementation for SPMD tensor addition NKI tutorial.

"""
import numpy as np
# NKI_EXAMPLE_27_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
  """NKI kernel to compute element-wise addition of two input tensors

  This kernel assumes that the input/output sizes can be uniformly tiled into [128, 512] tiles

  Args:
    a_input: the first input tensor
    b_input: the second input tensor

  Returns:
    c_output: an output tensor
  """
  # Create output tensor shared between all SPMD instances as result tensor
  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

  # Calculate tile offsets based on current 'program'
  offset_i_x = nl.program_id(0) * 128
  offset_i_y = nl.program_id(1) * 512

  # Generate tensor indices to index tensors a and b
  ix = offset_i_x + nl.arange(128)[:, None]
  iy = offset_i_y + nl.arange(512)[None, :]

  # Load input data from device memory (HBM) to on-chip memory (SBUF)
  # We refer to an indexed portion of a tensor as an intermediate tensor
  a_tile = nl.load(a_input[ix, iy])
  b_tile = nl.load(b_input[ix, iy])

  # Compute a + b
  c_tile = a_tile + b_tile

  # Store the addition result back to device memory (c_output)
  nl.store(c_output[ix, iy], value=c_tile)

  # Transfer the ownership of `c_output` to the caller
  return c_output
# NKI_EXAMPLE_27_END
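
# Worked example (illustrative only, values follow from the code above): with the
# (256, 1024) inputs used in __main__ below, the instance with program_id(0) == 1 and
# program_id(1) == 0 computes offset_i_x = 1 * 128 = 128 and offset_i_y = 0 * 512 = 0,
# so ix spans rows 128..255 and iy spans columns 0..511 — that instance loads, adds,
# and stores exactly one [128, 512] tile.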


# NKI_EXAMPLE_28_BEGIN
def nki_tensor_add(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors

  This kernel caller lifts the tile-size restriction by applying the kernel to tiles of the inputs/outputs.

  Args:
    a_input: the first input tensor, of shape [N*128, M*512]
    b_input: the second input tensor, of shape [N*128, M*512]

  Returns:
    a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where each invocation processes a 128x512 tile.
  grid_x = a_input.shape[0] // 128
  grid_y = a_input.shape[1] // 512

  return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
# NKI_EXAMPLE_28_END
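
# Worked example (illustrative only): for the (256, 1024) inputs in __main__ below,
# grid_x = 256 // 128 = 2 and grid_y = 1024 // 512 = 2, so the kernel is launched as
# a 2 x 2 SPMD grid of 4 instances, each handling one [128, 512] tile of the
# inputs/outputs.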

if __name__ == "__main__":
  a = np.random.rand(256, 1024).astype(np.float16)
  b = np.random.rand(256, 1024).astype(np.float16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_np = a + b
  print(f"output_np={output_np}")

  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and NumPy match")
  else:
    print("NKI and NumPy differ")

  assert allclose
@@ -0,0 +1,35 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

PyTorch implementation for SPMD tensor addition NKI tutorial.

"""
# NKI_EXAMPLE_29_BEGIN
import torch
from torch_xla.core import xla_model as xm
# NKI_EXAMPLE_29_END

from spmd_tensor_addition_nki_kernels import nki_tensor_add


# NKI_EXAMPLE_29_BEGIN
if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
# NKI_EXAMPLE_29_END