[NPU]: optimize GEGLU implementation with flatten 1D approach #1031
Open
noemotiovon wants to merge 1 commit into linkedin:main from noemotiovon:op_geglu
+121 −195
```diff
@@ -1,266 +1,191 @@
 """
 UB-aware GEGLU implementation for Ascend NPU.
 This implementation automatically adjusts block sizes to fit within UB constraints,
 preventing UB overflow errors when running on Ascend NPU.
 It reuses the original kernels when possible, and only uses tiling when necessary.
 """
-import operator
-
 import torch
 import triton
 import triton.language as tl
 
+from triton.language.math import tanh
+
 from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
-from liger_kernel.ops.utils import calculate_settings
-from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
-from liger_kernel.utils import is_npu_available
-
-if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
-    try:
-        from triton.language.extra.libdevice import tanh
-    except ModuleNotFoundError:
-        from triton.language.extra.cuda.libdevice import tanh
-else:
-    from triton.language.math import tanh
+from liger_kernel.ops.utils import get_npu_core_count
 
 
 @triton.jit
-def _geglu_tanh_forward_kernel_npu(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
     """
-    UB-aware GEGLU forward kernel for NPU.
-    Uses tiling loop to handle cases where BLOCK_SIZE < n_cols (due to UB constraints).
-    When BLOCK_SIZE >= n_cols, the loop executes only once, maintaining original behavior.
+    High-performance GEGLU forward kernel using flatten 1D approach.
+    Uses grid-stride loop pattern for optimal performance on NPU.
     """
-    program_id = tl.program_id(0).to(tl.int64)
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
 
-    # locate start index
-    a += program_id * stride
-    b += program_id * stride
-    c += program_id * stride
+    # Grid-Stride Loop
+    start_idx = pid * BLOCK_SIZE
+    stride = num_progs * BLOCK_SIZE
 
+    # Constants for GELU tanh approximation
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+    gelu_coeff = 0.044715
 
-    # Process in tiles when BLOCK_SIZE < n_cols
-    for i in range(0, n_cols, BLOCK_SIZE):
-        col_offsets = i + tl.arange(0, BLOCK_SIZE)
-        mask = col_offsets < n_cols
+    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+        offsets = idx + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < total_elements
 
-        a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
-        b_row = tl.load(b + col_offsets, mask=mask, other=0)
+        a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0)
 
         # tanh approximation form of GELU is computed with:
         # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
-        sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
-        a_cubed = a_row * a_row * a_row
-        tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+        a_cubed = a_val * a_val * a_val
+        tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_cubed)
         tanh_result = tanh(tanh_arg)
-        geglu_a = 0.5 * a_row * (1 + tanh_result)
-        c_row = geglu_a.cast(b_row.dtype) * b_row
-
-        tl.store(c + col_offsets, c_row, mask=mask)
+        geglu_a = 0.5 * a_val * (1.0 + tanh_result)
+        c_row = geglu_a.cast(b_val.dtype) * b_val
+        tl.store(c_ptr + offsets, c_row, mask=mask)
 
 
 @triton.jit
-def _geglu_tanh_backward_kernel_npu(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+def _geglu_backward_kernel_flat(
+    dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+):
     """
-    UB-aware GEGLU backward kernel for NPU.
-    Uses tiling loop to handle cases where BLOCK_SIZE < n_cols (due to UB constraints).
-    When BLOCK_SIZE >= n_cols, the loop executes only once, maintaining original behavior.
+    High-performance GEGLU backward kernel using flatten 1D approach.
+    Uses grid-stride loop pattern for optimal performance on NPU.
     """
-    program_id = tl.program_id(0).to(tl.int64)
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+    start_idx = pid * BLOCK_SIZE
+    stride = num_progs * BLOCK_SIZE
 
-    # locate start index
-    dc += program_id * stride
-    a += program_id * stride
-    b += program_id * stride
+    # Constants for GELU tanh approximation
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+    gelu_coeff = 0.044715
 
-    # Process in tiles when BLOCK_SIZE < n_cols
-    for i in range(0, n_cols, BLOCK_SIZE):
-        col_offsets = i + tl.arange(0, BLOCK_SIZE)
-        mask = col_offsets < n_cols
+    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+        offsets = idx + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < total_elements
 
-        dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
-        a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
-        b_row = tl.load(b + col_offsets, mask=mask, other=0)
+        dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0)
+        a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
 
         # recomputation to save memory
-        sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
-        a_cubed = a_row * a_row * a_row
-        tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
+        a_cubed = a * a * a
+        tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a_cubed)
         tanh_result = tanh(tanh_arg)
-        geglu_a = 0.5 * a_row * (1 + tanh_result)
-        geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
+        geglu_a = 0.5 * a * (1 + tanh_result)
+        geglu_a = geglu_a.to(dc.dtype).to(tl.float32)
 
-        db_row = dc_row.cast(tl.float32) * geglu_a
+        db = dc.cast(tl.float32) * geglu_a
 
         # Gradient w.r.t. a can be computed with:
        # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
         # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
-        term1 = 0.5 * (1 + tanh_result)
+        term1 = 0.5 * (1.0 + tanh_result)
         tanh_sq = tanh_result * tanh_result
-        term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-        da_row = dc_row * b_row * (term1 + term2)
+        a_sq = a * a
+        term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a_sq))
+        da = dc * b * (term1 + term2)
 
-        tl.store(a + col_offsets, da_row, mask=mask)
-        tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
+        tl.store(da_ptr + offsets, da, mask=mask)
+        tl.store(db_ptr + offsets, db.to(dc.dtype), mask=mask)
 
 
-def geglu_forward(a, b):
+def get_optimal_block_size(total_elements, is_backward=False):
     """
-    UB-aware GEGLU forward pass for NPU.
-    Automatically adjusts block size to fit within UB constraints.
-    """
-    ori_shape = a.shape
-    n_cols = ori_shape[-1]
-    a = a.view(-1, n_cols)
-    b = b.view(-1, n_cols)
-    c = torch.empty_like(a)
-    n_rows = a.shape[0]
-
-    # Calculate desired block size
-    desired_block_size, num_warps = calculate_settings(n_cols)
-
-    # Compute tiling strategy based on UB capacity
-    dtype_size = a.element_size()
-    # GEGLU forward tiling strategy:
-    # - Calculates maximum safe block size based on UB capacity
-    # - Memory analysis (only buffers that occupy UB, excluding temporary variables):
-    #   * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes)
-    #   * Output: c_row (dtype_size bytes)
-    #   * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers
-    #     and don't occupy UB since they are only used once
-    #   * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0
-    #   * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0
-    # - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
-    # - shapes: ((n_cols,),)
-    # - tiling_dims: (0,) means first dimension can be tiled
-    # - Returns: ((block_size,),)
-    shapes = ((n_cols,),)
-    if dtype_size == 2:
-        memory_multiplier = 4.0
+    Calculate optimal Block Size using compute_default_tiling_strategy.
+
+    Args:
+        total_elements: Total number of elements to process
+        is_backward: Whether this is for backward pass (requires more memory)
+
+    Returns:
+        Optimal block size for the kernel
+    """
+    # Memory multiplier based on peak memory usage analysis
+    if is_backward:
+        memory_multiplier = 6.0
     else:
         memory_multiplier = 3.0
+    # Call calculation function
+    # Treat input as 1D (total_elements,), only tiling on dim 0
     tile_shapes = compute_default_tiling_strategy(
-        safety_margin=0.80,
-        dtype_size=dtype_size,
+        safety_margin=0.9,
+        dtype_size=4,
         memory_multiplier=memory_multiplier,
-        shapes=shapes,
+        shapes=((total_elements,),),
         tiling_dims=(0,),
     )
 
-    if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
-        # Strategy returns ((block_size,),)
-        adjusted_block_size = tile_shapes[0][0]
+    # Parse result
+    if tile_shapes and len(tile_shapes) > 0:
+        block_size = tile_shapes[0][0]
+        return max(256, block_size)
     else:
-        # Fallback to desired block size if no best practice found (no tiling needed)
-        adjusted_block_size = desired_block_size
-
-    # Always use the unified NPU kernel
-    # When adjusted_block_size >= n_cols, the loop executes only once (no tiling)
-    # When adjusted_block_size < n_cols, the loop handles tiling automatically
-    _geglu_tanh_forward_kernel_npu[(n_rows,)](
-        a,
-        b,
-        c,
-        c.stride(-2),
-        n_cols=n_cols,
-        BLOCK_SIZE=adjusted_block_size,
-        num_warps=num_warps,
-    )
-    return a, b, c.view(*ori_shape)
+        return 2048
 
 
-def geglu_backward(a, b, dc):
-    """
-    UB-aware GEGLU backward pass for NPU.
-    Automatically adjusts block size to fit within UB constraints.
-    """
-    ori_shape = dc.shape
-    n_cols = ori_shape[-1]
-    dc = dc.view(-1, n_cols)
-    n_rows = dc.shape[0]
-
-    # Calculate desired block size
-    desired_block_size, num_warps = calculate_settings(n_cols)
-
-    # Compute tiling strategy based on UB capacity
-    dtype_size = dc.element_size()
-    # GEGLU backward tiling strategy:
-    # - Calculates maximum safe block size based on UB capacity
-    # - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation)
-    #   At this point, the following buffers simultaneously occupy UB:
-    #   1. dc_row = tl.load(dc + col_offsets, ...)                 # dtype_size bytes
-    #   2. a_row = tl.load(a + col_offsets, ...).to(tl.float32)    # 4 bytes (float32)
-    #   3. b_row = tl.load(b + col_offsets, ...)                   # dtype_size bytes
-    #   4. tanh_result = tanh(tanh_arg)                            # 4 bytes (float32), used in lines 95, 103, 104
-    #   5. geglu_a = 0.5 * a_row * (1 + tanh_result)               # 4 bytes (float32), used in lines 96, 98
-    #   6. db_row = dc_row.cast(tl.float32) * geglu_a              # 4 bytes (float32, computed at line 98, stored at line 109)
-    #   Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB
-    #   Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers
-    #   and don't occupy UB since they are only used once
-    #   * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4)
-    #     = 20 bytes/element, ratio = 20/2 = 10.0
-    #   * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4)
-    #     = 24 bytes/element, ratio = 24/4 = 6.0
-    # - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
-    # - shapes: ((n_cols,),)
-    # - tiling_dims: (0,) means first dimension can be tiled
-    # - Returns: ((block_size,),)
-    shapes = ((n_cols,),)
-    if dtype_size == 2:
-        memory_multiplier = 10.0
-    else:
-        memory_multiplier = 6.0
-    tile_shapes = compute_default_tiling_strategy(
-        safety_margin=0.80,
-        dtype_size=dtype_size,
-        memory_multiplier=memory_multiplier,
-        shapes=shapes,
-        tiling_dims=(0,),
-    )
-
-    if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
-        # Strategy returns ((block_size,),)
-        adjusted_block_size = tile_shapes[0][0]
-    else:
-        # Fallback to desired block size if no best practice found (no tiling needed)
-        adjusted_block_size = desired_block_size
-
-    # Always use the unified NPU kernel
-    # When adjusted_block_size >= n_cols, the loop executes only once (no tiling)
-    # When adjusted_block_size < n_cols, the loop handles tiling automatically
-    _geglu_tanh_backward_kernel_npu[(n_rows,)](
-        dc,
-        a,
-        b,
-        dc.stride(-2),
-        n_cols=n_cols,
-        BLOCK_SIZE=adjusted_block_size,
-        num_warps=num_warps,
-    )
-
-    return a.view(*ori_shape), b.view(*ori_shape)
+def geglu_forward(a, b):
+    """
+    High-performance GEGLU forward pass for NPU using flatten 1D approach.
+    """
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+
+    total_elements = a.numel()
+    c = torch.empty_like(a)
+
+    block_size = get_optimal_block_size(total_elements, is_backward=False)
+
+    num_cores = get_npu_core_count()
+    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+    _geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
+    return c
+
+
+def geglu_backward(a, b, dc):
+    """
+    High-performance GEGLU backward pass for NPU using flatten 1D approach.
+    """
+    if not dc.is_contiguous():
+        dc = dc.contiguous()
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+
+    total_elements = dc.numel()
+    grad_a = torch.empty_like(a)
+    grad_b = torch.empty_like(b)
+
+    block_size = get_optimal_block_size(total_elements, is_backward=True)
+
+    num_cores = get_npu_core_count()
+    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+    _geglu_backward_kernel_flat[(grid_size,)](
+        dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
+    )
+    return grad_a, grad_b
 
 
 class LigerGELUMulFunction(torch.autograd.Function):
-    """UB-aware GEGLU function for Ascend NPU."""
+    """High-performance GEGLU function for Ascend NPU."""
 
     @staticmethod
     @ensure_contiguous
     def forward(ctx, a, b):
-        a, b, c = geglu_forward(a, b)
+        c = geglu_forward(a, b)
         ctx.save_for_backward(a, b)
         return c
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dc):
         a, b = ctx.saved_tensors
-        a, b = geglu_backward(a, b, dc)
-        return a, b
+        grad_a, grad_b = geglu_backward(a, b, dc)
+        return grad_a, grad_b
```
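For readers who want to sanity-check what the flattened kernels compute, the tanh-form GEGLU written in the kernel comments above can be expressed directly in PyTorch. The following is a minimal sketch for cross-checking on small inputs, not part of this PR; the function name, device string, and tolerances in the usage note are illustrative assumptions.

```python
import math

import torch


def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference GEGLU: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b."""
    a_f32 = a.float()
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    inner = sqrt_2_over_pi * (a_f32 + 0.044715 * a_f32.pow(3))
    gelu_a = 0.5 * a_f32 * (1.0 + torch.tanh(inner))
    return gelu_a.to(b.dtype) * b


# Hypothetical cross-check against the kernel output (shapes, device, and tolerances are illustrative):
# a = torch.randn(2, 128, 4096, dtype=torch.bfloat16, device="npu")
# b = torch.randn_like(a)
# torch.testing.assert_close(LigerGELUMulFunction.apply(a, b), geglu_tanh_reference(a, b), rtol=1e-2, atol=1e-2)
```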
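To make the launch math concrete: the grid is capped at the NPU core count, and each program then walks the flattened tensor with a stride of `grid_size * BLOCK_SIZE`. Below is a small plain-Python illustration of that decomposition; the numbers, including `num_cores = 40`, are assumptions for the example, not values queried from hardware.

```python
# Grid-stride decomposition used by the flat kernels (illustrative values only).
total_elements = 1_000_000
block_size = 2048   # e.g. what get_optimal_block_size might return
num_cores = 40      # assumed core count; the real code uses get_npu_core_count()

num_blocks = (total_elements + block_size - 1) // block_size  # ceil-div -> 489 tiles
grid_size = min(num_cores, num_blocks)                        # 40 programs launched

# Program `pid` starts at pid * block_size and jumps by grid_size * block_size,
# so the 489 tiles are dealt out round-robin across the 40 programs.
stride = grid_size * block_size
tiles_per_program = [len(range(pid * block_size, total_elements, stride)) for pid in range(grid_size)]
assert sum(tiles_per_program) == num_blocks
```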
Do you know which tensor couldn't pass at this tolerance, the gradients or the inputs?
Thanks for the question. I double-checked which tensors require the large tolerance.
On NPU with bfloat16:
So the forward and weight gradients are already numerically different at ~1e2, and the input gradients further amplify this difference.
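For anyone reproducing this, comparing each output tensor against a reference separately makes it easy to see which one trips the tolerance. A hypothetical helper along these lines could be used; the names and tolerance values are illustrative, not the test suite's exact settings.

```python
import torch


def report_mismatch(name, actual, expected, rtol=1e-2, atol=1e-2):
    """Print max absolute/relative error and whether the pair fits the tolerance."""
    diff = (actual.float() - expected.float()).abs()
    rel = diff / expected.float().abs().clamp_min(1e-12)
    ok = torch.allclose(actual.float(), expected.float(), rtol=rtol, atol=atol)
    print(f"{name}: max_abs={diff.max().item():.3e} max_rel={rel.max().item():.3e} within_tol={ok}")


# Usage (hypothetical): check forward output and each gradient one by one
# report_mismatch("forward", c_triton, c_reference)
# report_mismatch("grad_a", a_triton.grad, a_reference.grad)
# report_mismatch("grad_b", b_triton.grad, b_reference.grad)
```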
Also worth noting: the tolerances used here are consistent with the previous NPU GEGLU kernel implementation, so this change does not introduce new numerical error compared to the existing behavior on NPU.