diff --git a/.gitignore b/.gitignore
index aaa3d00..2f5fb50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,7 +31,7 @@ __pycache__/
# Distribution / packaging
.Python
-build/
+docs/build/
develop-eggs/
dist/
downloads/
@@ -68,5 +68,4 @@ pip-delete-this-directory.txt
*.pyc
*.json
*.jsonl
-*_ignore.py
-.idea
\ No newline at end of file
+.idea
diff --git a/README.md b/README.md
index e7aedcd..08e3b3b 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-
@@ -20,16 +20,16 @@
## *News or Update* 🔥
---
-- [2025/07] We implement multiple kernels for fp8 training with `Megatron-LM` blockwise quantization.
+- [2025/07] We implemented multiple kernels for FP8 training with `Megatron-LM` blockwise quantization.
## Introduction
---
-Our repo, FLOPS, is designed for LLM training, especially for MoE training with fp8 quantizaiton. It provides 3 main categories of kernels:
+Our repo, linghe, is designed for LLM training, especially for MoE training with FP8 quantization. It provides 3 main categories of kernels (a usage sketch follows the list):
- **Fused quantization kernels**: fuse quantization with previous layer, e.g., RMS norm and Silu.
-- **Memory-friendly kernels**: use dtype cast in kernels instead of casting out kernels, e.g., softmax cross entropy and moe router gemm.
-- **Other fused kernels**: fuse multiple IO-itensive operations, e.g., ROPE with qk-norm and transpose, permute and padding, group RMS norm with sigmoid gate.
+- **Memory-efficient kernels**: fuse multiple IO-intensive operations, e.g., ROPE with qk-norm.
+- **Implementation-optimized kernels**: use efficient Triton implementations, e.g., routing map padding instead of activation padding.
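+
+A minimal usage sketch of the fused-quantization category (the import path and signature follow the API reference; shapes are illustrative):
+
+```python
+import torch
+from linghe.utils.norm import triton_rms_norm_and_block_quant_forward
+
+# Illustrative shapes: 8192 tokens, hidden size 2048 (matches the benchmark setup).
+x = torch.randn(8192, 2048, device="cuda", dtype=torch.bfloat16)
+weight = torch.ones(2048, device="cuda", dtype=torch.bfloat16)
+
+# One fused kernel launch: RMSNorm followed by blockwise FP8 quantization.
+# Per the API reference, output_mode=2 returns both the quantized tensor and
+# its transposed copy, together with the quantization scales and the rms.
+outputs = triton_rms_norm_and_block_quant_forward(x, weight, eps=1e-6, output_mode=2)
+```
+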
## Benchmark
@@ -37,25 +37,26 @@ Our repo, FLOPS, is designed for LLM training, especially for MoE training with
We benchmark on H800 with batch size 8192, hidden size 2048, num experts 256, activation experts 8.
| kernel | baseline(us) | linghe(us) | speedup |
-|--------|--------------|-----------|---------|
-| RMSNorm+Quantization(forward) | 159.3 us | 72.4 us | 2.2 |
-| Split+qk-norm+rope+transpose(forward) | 472 us | 59.1 us | 7.99 |
-| Split+qk-norm+rope+transpose(backward) | 645 us | 107.5 us | 6.0 |
-| Fp32 router gemm(forward) | 242.3 us | 61.6 us | 3.931 |
-| Fp32 router gemm(backward) | 232.7 us | 78.1 us | 2.979 |
-| Permute with padded indices | 388 us | 229.4 us | 1.69 |
-| Unpermute with padding indices | 988.6 us | 806.9 us | 1.23 |
-| Batch Silu+quantization(forward) | 6241.7 us | 1181.7 us | 5.28 |
-| Batch Silu+quantization(backward) | 7147.7 us | 2317.9 us | 3.08 |
-| Silu+quantization(forward) | 144.9 us | 58.2 us | 2.48 |
-| Silu+quantization(backward) | 163.4 us | 74.2 us | 2.2 |
-| fused linear gate(forward) | 160.4 us | 46.9 us | 3.42 |
-| fused linear gate(backward) | 572.9 us | 81.1 us | 7.06 |
-| Cross entropy(forward) | 2780.8 us | 818.2 us | 3.4 |
-| Cross entropy(backward) | 7086.3 us | 1781.0 us | 3.98 |
-| batch grad norm | 1733.7 us | 1413.7 us | 1.23 |
-| Batch count zero | 4997.9 us | 746.8 us | 6.69 |
-
+|--------|--------------|------------|---------|
+| RMSNorm+Quantization (forward) | 159.3 | 72.4 | 2.2 |
+| Split+qk-norm+rope+transpose (forward) | 472 | 59.1 | 7.99 |
+| Split+qk-norm+rope+transpose (backward) | 645 | 107.5 | 6.0 |
+| FP32 router gemm (forward) | 242.3 | 61.6 | 3.931 |
+| FP32 router gemm (backward) | 232.7 | 78.1 | 2.979 |
+| Permute with padded indices | 388 | 229.4 | 1.69 |
+| Unpermute with padded indices | 988.6 | 806.9 | 1.23 |
+| Batch Silu+quantization (forward) | 6241.7 | 1181.7 | 5.28 |
+| Batch Silu+quantization (backward) | 7147.7 | 2317.9 | 3.08 |
+| Silu+quantization (forward) | 144.9 | 58.2 | 2.48 |
+| Silu+quantization (backward) | 163.4 | 74.2 | 2.2 |
+| Fused linear gate (forward) | 160.4 | 46.9 | 3.42 |
+| Fused linear gate (backward) | 572.9 | 81.1 | 7.06 |
+| Cross entropy (forward) | 2780.8 | 818.2 | 3.4 |
+| Cross entropy (backward) | 7086.3 | 1781.0 | 3.98 |
+| Batch grad norm | 1733.7 | 1413.7 | 1.23 |
+| Batch count zero | 4997.9 | 746.8 | 6.69 |
+
+Other benchmark results can be obtained by running the scripts in the `tests` and `benchmark` folders.
## Examples
---
@@ -65,4 +66,4 @@ Examples can be found in tests.
## Api Reference
---
-Please refer to [API doc](docs/api.md)
\ No newline at end of file
+Please refer to the [API documentation](https://inclusionai.github.io/linghe/).
\ No newline at end of file
diff --git a/asserts/linghe.png b/asserts/linghe.png
new file mode 100644
index 0000000..c1778e4
Binary files /dev/null and b/asserts/linghe.png differ
diff --git a/build.sh b/build.sh
index 904a121..b123d1c 100644
--- a/build.sh
+++ b/build.sh
@@ -2,4 +2,6 @@ rm -rf build &&
rm -rf dist &&
rm -rf linghe.egg-info &&
python setup.py develop &&
-python setup.py bdist_wheel &&
+python setup.py bdist_wheel
+
+# pdoc --output-dir docs -d google --no-include-undocumented --no-search --no-show-source linghe
\ No newline at end of file
diff --git a/docs/api.md b/docs/api.md
deleted file mode 100644
index 000bfc2..0000000
--- a/docs/api.md
+++ /dev/null
@@ -1,212 +0,0 @@
-# API Reference
-
-
-```
-linghe.utils.norm.triton_rms_norm_and_block_quant_forward(x, weight, eps:Optional[float]=1e-6, out:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None, rms:Optional[torch.Tensor]=None, round_scale: Optional[bool]=False, output_mode:Optional[int]=2)
-```
-
-Computes the forward pass of RMSNorm and block quantization.
-
-**Parameters:**
-- x(*torch.Tensor*) - Input tensor. [M, N]
-- weight(*torch.Tensor*) - RMSNorm weight. [N]
-- eps(*float*) - epsilon value for L2 normalization.
-- round_scale(*bool*) - Set whether to force power of 2 scales.
-- rms(*torch.Tensor*) - Reciprocal of the root mean square of the input calculated over the last dimension.[N]
-- output_mode - (*int*, {0, 1, 2}, default = 2) 0 only output non-transpose tensor, 1 only output transposed tensor, 2 return both.
-
----
-
-**`
-Class linghe.facade.rope.QkNormHalfRopeFunction
-`**
-
-```
-forward(qkv:, q_norm_weight, k_norm_weight, freqs, H, h, eps:Optional[float]=1e-6)
-```
-Split qkv, and apply L2 nrom and ROPE on q and k.
-
-**Parameters:**
-- qkv(*torch.Tensor*) - QKV tensor with size of [S, B, dim]
-- freqs(*torch.Tensor*) - Freqs matrix based on half dim.
-- H(*int*) - Number of attention heads.
-- h(*int*) - Number of query groups.
-- eps(*float*) - epsilon value for L2 normalization.
-
-```
-backward(grad_q, grad_k, grad_v)
-```
-**Parameters:**
-- grad_q(*torch.Tensor*) Grad of q tensor.
-- grad_k(*torch.Tensor*) Grad of k tensor.
-- grad_v(*torch.Tensor*) Gard of v tensor.
-
----
-
-**`
-Class linghe.facade.fp32_linear.FusedFp32GEMM
-`**
-
-Optimized fp32 gemm in router gate function. Convert bf16 input and weight to float32 during the gemm operation.
-
-```
-forward(input, weight)
-```
-**Parameters:**
-- input(*torch.Tensor*) - Input tensor with [B, S, dim], dtype of bf16.
-- weight(*torch.Tensor*) - Weight tensor of router.
-
-```
-backward(grad_output)
-```
-**Parameters:**
-- grad_output(*torch.Tensor*) - Gradient of the activation.
-
----
-
-```
-linghe.utils.gather.triton_permute_with_mask_map(inp, scale, probs, row_id_map, num_out_tokens, contiguous, tokens_per_expert)
-```
-Permute the tokens and probs based on the routing map. Index indicates row index of the output tensor(-1 means not selected). Perform well even when inp.size(0) < expert padding number, do not need extra explict padding.
-
-**Parameters:**
-- inp(*torch.Tensor*) - Input hidden.[num_tokens, hidden_size]
-- scale(*torch.Tensor*) - [num_tokens, scale_size]
-- prob(*torch.Tensor*) - [num_tokens] Router prob.
-- row_id_map(*torch.Tensor*) - [n_experts, num_tokens] Index indicates row index of the output tensor.
-- num_out_tokens(*int*) - Output token count, including padding tokens.
-- contiguous(*bool*) - Whether indices in row_id_map is contiguous, should be False if padded.
-- token_per_expert(bool) - [num_experts] Token count per expert, non-blocking cuda tensor.
-
----
-
-```
-linghe.utils.scatter.triton_unpermute_with_mask_map(grad, row_id_map, probs)
-```
-Unpermute a tensor with permuted tokens with router mapping.
-
-**Parameters:**
-- inp(*torch.Tensor*) - [num_tokens, hidden_size] Permuted tokens.
-- row_id_map(*torch.Tensor*) - [n_experts, num_tokens] Routing map to unpermute the tokens.
-- prob(*torch.Tensor*) - [num_out_tokens] Permuted probs.
-
----
-
-```
-linghe.util.silu.triton_silu_and_block_quant_forward(x, out:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None, round_scale:Optional[bool]=False, output_mode:Optional[int]=2)
-```
-
-Applies the forward pass of Sigmoid Linear Unit(SiLU) element-wise and block quant.(used in shared expert layers.)
-
-**Parameters:**
-- x(*torch.Tensor*) - Input tensor to be quanted.
-- round_scale(*bool*) - Set whether to force power of 2 scales.
-- output_mode - (*int*, {0, 1, 2}, default = 2) 0 only output non-transpose tensor, 1 only output transposed tensor, 2 return both.
-
----
-
-```
-linghe.util.silu.triton_silu_and_block_quant_backward(g, x, round_scale:Optional[bool]=False)
-```
-**Parameters:**
-- g(*torch.Tensor*) - Gradient tensor to be quanted.
-- x(*torch.Tensor*) - Input tensor.
-- round_scale(*bool*) - Set whether to force power of 2 scales. Default to False.
-
----
-
-```
-linghe.util.silu.triton_batch_weighted_silu_and_block_quant_forward(x, weight, counts, splits:Optional[List]=None ,out:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None, round_scale:Optional[bool]=False, output_mode:Optional[int]=2)
-```
-
-Fused op for batched weighted SiLU and block quant.
-
-**Parameters:**
-- x(*torch.Tensor*) - Input tensor.
-- weight(*torch.Tensor*) - Permuted probs
-- couts(*torch.Tensor*) - Tokens per expert cuda tensor.
-- splits(*List[int]*) - List of tokens per expert. If compute in batch mode should not be None.
-- output_mode - (*int*, {0, 1, 2}, default = 2) 0 only output non-transpose tensor, 1 only output transposed tensor, 2 return both.
-
----
-
-```
-linghe.util.silu.triton_batch_weighted_silu_and_block_quant_backward(g, x, weight, counts, splits:Optional[List]=None, round_scale:Optional[bool]=False)
-```
-**Parameters:**
-- g(*torch.Tensor*) - Input gradient tensor.
-- x(*torch.Tensor*) - Input tensor.
-- weight(*torch.Tensor*) - Permuted probs
-- couts(*torch.Tensor*) - Tokens per expert cuda tensor.
-- splits(*List[int]*) - List of tokens per expert. If compute in batch mode should not be None.
-
----
-
-**`
-Class linghe.facade.loss.SoftmaxCrossEntropyFunction
-`**
-
-Prallel version of SoftmaxCrossEntropy.
-
-```
-forward(logits, labels, inplace:Optional[bool]=False)
-```
-
-Fast impl of softmax cross entropy.
-
-**Parameters:**
-- logits(*torch.Tensor*) - Input logits.
-- labels(*torch.Tensor*) - Input labels.
-- inplace(*bool*) - Flag save for backward, whether logits ptr should replaced by grads tensor ptr.
-
-```
-backward(grad_output)
-```
-
-**Parameters:**
-- grad_output(*torch.Tensor*) - Gradients tensor.
-
----
-
-```
-linghe.util.reduce.triton_batch_sum_with_ord(xs, ord:Optional[int]=2)
-```
-Square sum the gards of all the experts. All the experts grads are applied simultaneously.
-
-**Parameters:**
-- xs(*List[torch.Tensor]*) - Grads lists.
-- ord(*int*) - Sum type. 1 for abs add and 2 for square add.
-
----
-
-```
-linghe.util.reduce.triton_batch_count_zero(xs)
-```
-Prallel cout zeros in all the given grads lists.
-
-**Parameters:**
-- xs(*List[torch.Tensor]*) - Grads lists.
-
----
-
-**`
-Class linghe.facade.norm.GroupNormGateFunction
-`**
-Fused operation of group RMSNorm and sigmoid gate function.
-
-```
-forward(x, gate, weight, eps:Optional[float]=1e-6, group_size:Optional[int]=4)
-```
-Note that the output shape is transposed [S, B, dim]
-
-**Parameters:**
-
-- x(*torch.Tensor*) - [B, S, dim] Input tensor.
-- gate(*torch.Tensor*) - [S, B, dim]
-- weight(*torch.Tensor*) - [dim]
-
-```
-backward(grad)
-```
-**Parameters:**
-- grad(*torch.Tensor*) - [S, B, dim] Grads of input tensor.
diff --git a/docs/index.html b/docs/index.html
new file mode 100644
index 0000000..514a509
--- /dev/null
+++ b/docs/index.html
@@ -0,0 +1,7 @@
diff --git a/docs/linghe.html b/docs/linghe.html
new file mode 100644
index 0000000..a7cef32
--- /dev/null
+++ b/docs/linghe.html
@@ -0,0 +1,52 @@
+<!-- pdoc-generated single-page API reference for the linghe package.
+     Documented modules:
+       linghe.facade - torch.autograd.Function wrappers (InplaceAddFunction, FusedFp32GEMM,
+                       SoftmaxCrossEntropyFunction, GradScalingFunction, RMSNormFunction,
+                       GroupNormGateFunction, QkNormHalfRopeFunction, TransposeDim01Function)
+       linghe.gemm   - fp32 and scaled-fp32 GEMM kernels used for the MoE router
+       linghe.quant  - blockwise, groupwise, and channel/rowwise FP8 quantization kernels,
+                       including hadamard-transform quantization
+       linghe.utils  - fused gather/scatter (permute/unpermute with row id map),
+                       SiLU + block quantization, RMSNorm + block quantization,
+                       qk-norm + half-rope, softmax cross entropy, reduce, and transpose kernels
+     Generated HTML markup and the embedded search index are omitted. -->
"funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.fp8_channel_f_and_b", "modulename": "linghe.quant.channel.channel", "qualname": "fp8_channel_f_and_b", "kind": "function", "doc": "\n", "signature": "(x, w, y):", "funcdef": "def"}, {"fullname": "linghe.utils", "modulename": "linghe.utils", "kind": "module", "doc": "\n"}, {"fullname": "linghe.utils.add", "modulename": "linghe.utils.add", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.add.inplace_add_kernel", "modulename": "linghe.utils.add", "qualname": "inplace_add_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, y_ptr, M, N, H: int, W: int, EVEN: int, ACCUM: int):", "funcdef": "def"}, {"fullname": "linghe.utils.add.triton_inplace_add", "modulename": "linghe.utils.add", "qualname": "triton_inplace_add", "kind": "function", "doc": "inplace add y to x\nArgs:\n x: Tensor\n y: Tensor\n accum: whether accum y to x
\n\nReturns: x += y if accum=True else x.copy_(y)
\n", "signature": "(x: torch.Tensor, y: torch.Tensor, accum: bool = True):", "funcdef": "def"}, {"fullname": "linghe.utils.dot", "modulename": "linghe.utils.dot", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.dot.dot_kernel", "modulename": "linghe.utils.dot", "qualname": "dot_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, y_ptr, sum_ptr, M, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.triton_dot", "modulename": "linghe.utils.dot", "qualname": "triton_dot", "kind": "function", "doc": "\n", "signature": "(x, y):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.mix_precise_dot_kernel", "modulename": "linghe.utils.dot", "qualname": "mix_precise_dot_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tq_ptr,\tsum_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tN,\tH: int,\tW: int):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.triton_mix_precise_dot", "modulename": "linghe.utils.dot", "qualname": "triton_mix_precise_dot", "kind": "function", "doc": "\n", "signature": "(x, q, smooth_scale, quant_scale, reverse=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather", "modulename": "linghe.utils.gather", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.gather.block_count_kernel", "modulename": "linghe.utils.gather", "qualname": "block_count_kernel", "kind": "function", "doc": "\n", "signature": "(map_ptr, count_ptr, M, B, T: int, b: int, E: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.make_row_id_map_kernel", "modulename": "linghe.utils.gather", "qualname": "make_row_id_map_kernel", "kind": "function", "doc": "\n", "signature": "(map_ptr, count_ptr, output_ptr, M, B, P, T: int, b: int, E: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map", "kind": "function", "doc": "\n", "signature": "(routing_map: torch.Tensor, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.make_row_id_map_and_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "make_row_id_map_and_indices_kernel", "kind": "function", "doc": "\n", "signature": "(\tmap_ptr,\tcount_ptr,\trow_map_ptr,\trow_indices_ptr,\tM,\tB,\tP,\tT: int,\tb: int,\tE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map_and_indices", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map_and_indices", "kind": "function", "doc": "\n", "signature": "(routing_map: torch.Tensor, num_out_tokens: int, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.index_select_kernel", "modulename": "linghe.utils.gather", "qualname": "index_select_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tout_ptr,\tscale_ptr,\tscale_out_ptr,\tindex_ptr,\tM,\tT,\tN: int,\tSCALE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_index_select", "modulename": "linghe.utils.gather", "qualname": "triton_index_select", "kind": "function", "doc": "\n", "signature": "(x, indices, scale=None, out=None, scale_out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "permute_with_mask_map_kernel", "kind": "function", "doc": "\n", "signature": "(\tdata_ptr,\tscale_ptr,\tprobs_ptr,\tmask_map_ptr,\toutput_data_ptr,\toutput_scale_ptr,\toutput_probs_ptr,\tnum_experts: int,\tN: int,\ths: int,\tSCALE: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.fill_padded_token_with_zero_kernel", "modulename": "linghe.utils.gather", "qualname": "fill_padded_token_with_zero_kernel", "kind": "function", "doc": "\n", "signature": "(\tdata_ptr,\tscale_ptr,\tprobs_ptr,\tmax_indices_ptr,\ttoken_per_expert_ptr,\tN: int,\ths: int,\tSCALE: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_permute_with_mask_map", "kind": "function", "doc": "\n", "signature": "(\tinp: torch.Tensor,\tscale: torch.Tensor,\tprobs: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_out_tokens: int,\tcontiguous: bool = True,\ttokens_per_expert: Optional[torch.Tensor] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.batch_smooth_transpose_smooth_permute_kernel", "modulename": "linghe.utils.gather", "qualname": "batch_smooth_transpose_smooth_permute_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tscale_ptr,\toss_ptr,\tss_ptr,\tindex_ptr,\tcount_ptr,\taccum_ptr,\tq_ptr,\tqs_ptr,\tN: int,\tE: int,\tH: int,\tW: int,\tSMOOTHED: int,\tROUND: int):", "funcdef": "def"}, {"fullname": 
"linghe.utils.gather.triton_batch_transpose_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_batch_transpose_smooth_permute_with_indices", "kind": "function", "doc": "\n", "signature": "(\tx,\tscale,\torg_smooth_scale,\tsmooth_scales,\tindices,\ttoken_count_per_expert,\tsplits,\tx_q=None,\tx_scale=None,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_weighted_permute_with_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_weighted_permute_with_indices_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrads_ptr,\ttokens_ptr,\tq_ptr,\tss_ptr,\tqs_ptr,\tcount_ptr,\taccum_ptr,\tindex_ptr,\tsum_ptr,\tM,\tN: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_weighted_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_weighted_permute_with_indices", "kind": "function", "doc": "\n", "signature": "(\tgrads,\ttokens,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\tx_sum=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_permute_with_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_permute_with_indices_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrads_data_ptr,\tgrads_scale_ptr,\tq_ptr,\tss_ptr,\tqs_ptr,\tcount_ptr,\taccum_ptr,\tindex_ptr,\tN: int,\ths: int,\tREVERSE: int,\tROUND: int,\tGROUP: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_indices", "kind": "function", "doc": "\n", "signature": "(\tgrad_data,\tgrad_scale,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_permute_with_mask_map_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrads_data_ptr,\tquant_data_ptr,\tmask_map_ptr,\tgrads_scale_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tT,\tN: int,\ths: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_mask_map", "kind": "function", "doc": "\n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tscale: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.deprecated_smooth_permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "deprecated_smooth_permute_with_mask_map_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrads_data_ptr,\tquant_data_ptr,\tmask_map_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tT,\tN: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_deprecated_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_deprecated_smooth_permute_with_mask_map", "kind": "function", "doc": "\n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", 
"funcdef": "def"}, {"fullname": "linghe.utils.loss", "modulename": "linghe.utils.loss", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.loss.softmax_cross_entropy_forward_kernel", "modulename": "linghe.utils.loss", "qualname": "softmax_cross_entropy_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tlogit_ptr,\tlabel_ptr,\tloss_ptr,\tsum_exp_ptr,\tmax_logit_ptr,\tN,\tB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_forward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_forward", "kind": "function", "doc": "\n", "signature": "(logits, labels):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.softmax_cross_entropy_backward_kernel", "modulename": "linghe.utils.loss", "qualname": "softmax_cross_entropy_backward_kernel", "kind": "function", "doc": "\n", "signature": "(\tlogit_ptr,\tlabel_ptr,\tsum_exp_ptr,\tmax_logit_ptr,\tinput_grad_ptr,\toutput_grad_ptr,\tN,\tB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_backward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_backward", "kind": "function", "doc": "\n", "signature": "(logits, labels, sum_exp, max_logit, input_grad, output_grad=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm", "modulename": "linghe.utils.norm", "kind": "module", "doc": "\n"}, {"fullname": "linghe.utils.norm.rms_norm_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_forward_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, weight_ptr, out_ptr, eps, M, T, N: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_forward", "kind": "function", "doc": "\n", "signature": "(x, weight, eps=1e-06, out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_backward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_backward_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrad_output_ptr,\tx_ptr,\tw_ptr,\tdx_ptr,\tdw_ptr,\teps,\tM,\tT,\tN: int,\tW: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_backward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_backward", "kind": "function", "doc": "\n", "signature": "(grad_output, x, w, eps=1e-06):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\trms_ptr,\teps,\tM,\tT: int,\tN: int,\tnb: int,\tW: int,\tH: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_n_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_n_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\trms_ptr,\teps,\tM: int,\tT: int,\tN: int,\tnb: int,\tW: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_t_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_t_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tweight_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\trms_ptr,\tM,\tN,\tW: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_and_block_quant_forward", "modulename": "linghe.utils.norm", "qualname": 
"triton_rms_norm_and_block_quant_forward", "kind": "function", "doc": "Fused RMSNorm forward and block quantization.\nArgs:\n x: Input tensor, shape [M, N]\n weight: RMSNorm weight, shape [N]\n eps: epsilon value for L2 normalization.\n out: output of quantization data\n scale: output of quantization scale.\n rms: output of rms\n round_scale: Set whether to force power of 2 scales.\n output_mode: one of {0, 1, 2}.\n 0: only output non-transpose tensor\n 1: only output transposed tensor\n 2: return both\nReturns:\n out: quantization data\n scale: quantization scale\n rms: Reciprocal of the root mean square of the input calculated over the last dimension.\n transpose_output: quantization data of transposed gradient\n transpose_scale: quantization scale of transposed gradient
\n", "signature": "(\tx: torch.Tensor,\tweight: torch.Tensor,\teps: float = 1e-06,\tout: Optional[torch.Tensor] = None,\tscale: Optional[torch.Tensor] = None,\trms: Optional[torch.Tensor] = None,\tround_scale: bool = False,\toutput_mode: int = 2):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.group_norm_gate_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "group_norm_gate_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tgate_ptr,\tweight_ptr,\tout_ptr,\teps,\tbs,\tlength,\tDIM: int,\tD: int,\tGROUP_SIZE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_forward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_forward", "kind": "function", "doc": "norm and gate in linear attention\nArgs:\n x:\n gate:\n weight:\n eps:\n group_size:
\n\nReturns:
\n", "signature": "(x: torch.Tensor, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.group_rms_gate_backward_kernel", "modulename": "linghe.utils.norm", "qualname": "group_rms_gate_backward_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrad_output_ptr,\tx_ptr,\tgate_ptr,\tw_ptr,\tdx_ptr,\tdg_ptr,\tdw_ptr,\teps,\tbs,\tlength,\tDIM: int,\tD: int,\tGROUP_SIZE: int,\tT: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_backward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_backward", "kind": "function", "doc": "\n", "signature": "(grad_output, x, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange", "modulename": "linghe.utils.rearange", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.rearange.split_and_cat_kernel", "modulename": "linghe.utils.rearange", "qualname": "split_and_cat_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\ty_ptr,\tscale_ptr,\tscale_output_ptr,\tcount_ptr,\taccum_ptr,\trev_accum_ptr,\tindex_ptr,\tM,\tN: int,\tSCALE: int,\tK: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange.triton_split_and_cat", "modulename": "linghe.utils.rearange", "qualname": "triton_split_and_cat", "kind": "function", "doc": "\n", "signature": "(x, counts, indices, scales=None):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce", "modulename": "linghe.utils.reduce", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.reduce.abs_max_kernel", "modulename": "linghe.utils.reduce", "qualname": "abs_max_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tscale_ptr,\tsmooth_scale_ptr,\toutput_ptr,\tmin_value,\tM,\tN,\tH: int,\tW: int,\tEVEN: int,\tQUANTIZED: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_abs_max", "modulename": "linghe.utils.reduce", "qualname": "triton_abs_max", "kind": "function", "doc": "\n", "signature": "(x, scale=None, smooth_scale=None, min_value=1e-30, axis=0):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.batch_count_zero_kernel", "modulename": "linghe.utils.reduce", "qualname": "batch_count_zero_kernel", "kind": "function", "doc": "\n", "signature": "(input_ptrs, size_ptr, count_ptr, B: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_count_zero", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_count_zero", "kind": "function", "doc": "\n", "signature": "(xs):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.batch_sum_with_ord_kernel", "modulename": "linghe.utils.reduce", "qualname": "batch_sum_with_ord_kernel", "kind": "function", "doc": "\n", "signature": "(input_ptrs, size_ptr, count_ptr, B: int, ORD: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_sum_with_ord", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_sum_with_ord", "kind": "function", "doc": "\n", "signature": "(xs, ord=2):", "funcdef": "def"}, {"fullname": "linghe.utils.rope", "modulename": "linghe.utils.rope", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.rope.half_rope_forward_kernel", "modulename": "linghe.utils.rope", "qualname": "half_rope_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tq_ptr,\tk_ptr,\tfreqs_ptr,\tqo_ptr,\tko_ptr,\tB,\tq_stride,\tk_stride,\tH: int,\th: int,\tD: int,\td: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_forward", "kind": "function", "doc": "\n", "signature": "(q, k, freqs):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.half_rope_backward_kernel", "modulename": "linghe.utils.rope", "qualname": "half_rope_backward_kernel", "kind": "function", "doc": "\n", "signature": "(q_ptr, k_ptr, freqs_ptr, B, H: int, h: int, D: int, d: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_backward", "kind": "function", "doc": "\n", "signature": "(q_grad, k_grad, freqs, inplace=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.qk_norm_and_half_rope_forward_kernel", "modulename": "linghe.utils.rope", "qualname": "qk_norm_and_half_rope_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tqkv_ptr,\tq_norm_weight_ptr,\tk_norm_weight_ptr,\tfreqs_ptr,\tqo_ptr,\tko_ptr,\tvo_ptr,\tB,\tstride,\teps,\tH: int,\th: int,\tD: int,\td: int,\tinterleave: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_forward", "kind": "function", "doc": "\n", "signature": "(\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\tH=32,\th=4,\teps=1e-06,\tinterleave=True,\ttranspose=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.qk_norm_and_half_rope_backward_kernel", "modulename": "linghe.utils.rope", "qualname": "qk_norm_and_half_rope_backward_kernel", "kind": "function", "doc": "\n", "signature": "(\tgq_ptr,\tgk_ptr,\tgv_ptr,\tqkv_ptr,\tq_norm_weight_ptr,\tk_norm_weight_ptr,\tfreqs_ptr,\tdqkv_ptr,\tdqw_ptr,\tdkw_ptr,\tB,\tstride,\teps,\tH: int,\th: int,\tD: int,\td: int,\tinterleave: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_backward", "kind": "function", "doc": "\n", "signature": "(\tgq,\tgk,\tgv,\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\teps=1e-06,\ttranspose=False,\tinterleave=True):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter", "modulename": "linghe.utils.scatter", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.scatter.aligned_scatter_add_kernel", "modulename": "linghe.utils.scatter", "qualname": "aligned_scatter_add_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\to_ptr,\tindices_ptr,\tweights_ptr,\tM,\tN: int,\tK: int,\tSCALE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_aligned_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_aligned_scatter_add", "kind": "function", "doc": "\n", "signature": "(x, outputs, indices, weights=None):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.scatter_add_kernel", "modulename": "linghe.utils.scatter", "qualname": "scatter_add_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, o_ptr, indices_ptr, M, T, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.fp32_to_bf16_kernel", "modulename": "linghe.utils.scatter", "qualname": "fp32_to_bf16_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, o_ptr, M, T, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_scatter_add", "kind": "function", "doc": "\n", "signature": "(x, outputs, indices):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.unpermute_with_mask_map_kernel", "modulename": "linghe.utils.scatter", "qualname": "unpermute_with_mask_map_kernel", "kind": "function", "doc": "\n", "signature": "(\tgrads_ptr,\tprobs_ptr,\tmask_map_ptr,\toutput_ptr,\toutput_probs_ptr,\tnum_experts: int,\tN: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_unpermute_with_mask_map", "modulename": "linghe.utils.scatter", "qualname": "triton_unpermute_with_mask_map", "kind": "function", "doc": "\n", "signature": "(grad: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.utils.silu", "modulename": "linghe.utils.silu", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.silu.silu_and_block_quant_forward_kernel", "modulename": "linghe.utils.silu", "qualname": "silu_and_block_quant_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\tM,\tn: int,\tROUND: int,\tOUTPUT_MODE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_forward", "kind": "function", "doc": "\n", "signature": "(x, out=None, scale=None, round_scale=False, output_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.silu_and_block_quant_backward_kernel", "modulename": "linghe.utils.silu", "qualname": "silu_and_block_quant_backward_kernel", "kind": "function", "doc": "\n", "signature": "(\tg_ptr,\tx_ptr,\tdx_ptr,\tdx_scale_ptr,\ttranspose_dx_ptr,\ttranspose_dx_scale_ptr,\tM,\tn: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_backward", "kind": "function", "doc": "\n", "signature": "(g, x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.batch_weighted_silu_and_block_quant_forward_kernel", "modulename": "linghe.utils.silu", "qualname": "batch_weighted_silu_and_block_quant_forward_kernel", "kind": "function", "doc": "\n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\tcount_ptr,\taccum_ptr,\tn: int,\tE: int,\tROUND: int,\tOUTPUT_MODE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_forward", "kind": "function", "doc": "\n", "signature": "(\tx,\tweight,\tcounts,\tsplits=None,\tout=None,\tscale=None,\tround_scale=False,\toutput_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.batch_weighted_silu_and_block_quant_backward_kernel", "modulename": "linghe.utils.silu", "qualname": "batch_weighted_silu_and_block_quant_backward_kernel", "kind": "function", "doc": "\n", "signature": "(\tg_ptr,\tx_ptr,\tweight_ptr,\tcount_ptr,\taccum_ptr,\tdx_ptr,\tdx_scale_ptr,\ttranspose_dx_ptr,\ttranspose_dx_scale_ptr,\tdw_ptr,\tn: int,\tE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_backward", "kind": "function", "doc": "\n", "signature": "(g, x, weight, counts, splits=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose", "modulename": "linghe.utils.transpose", "kind": "module", "doc": "Copyright (c) Ant Financial Service Group and its affiliates.
\n"}, {"fullname": "linghe.utils.transpose.deprecated_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "deprecated_transpose_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, t_ptr, M, N, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_depracated_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_depracated_transpose", "kind": "function", "doc": "\n", "signature": "(x):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, t_ptr, M, N, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_dim_0_1_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_dim_0_1_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, t_ptr, B, M, b_stride, m_stride, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose", "kind": "function", "doc": "\n", "signature": "(x, dim0=None, dim1=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_and_pad_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_and_pad_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, t_ptr, M, N, P, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose_and_pad", "kind": "function", "doc": "\n", "signature": "(x, out=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.batch_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "batch_transpose_kernel", "kind": "function", "doc": "\n", "signature": "(xs_ptr, xts_ptr, M, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose", "kind": "function", "doc": "\n", "signature": "(xs, xts=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.batch_transpose_and_pad_kernel", "modulename": "linghe.utils.transpose", "qualname": "batch_transpose_and_pad_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, t_ptr, count_ptr, accum_ptr, pad_accum_ptr, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose_and_pad", "kind": "function", "doc": "\n", "signature": "(x, count_list, x_t=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.configs", "modulename": "linghe.utils.transpose", "qualname": "configs", "kind": "variable", "doc": "\n", "default_value": "[<triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, 
<triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>]"}, {"fullname": "linghe.utils.transpose.opt_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "opt_transpose_kernel", "kind": "function", "doc": "\n", "signature": "(x_ptr, t_ptr, M, N, D, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_opt_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_opt_transpose", "kind": "function", "doc": "\n", "signature": "(x):", "funcdef": "def"}]; + + // mirrored in build-search-index.js (part 1) + // Also split on html tags. this is a cheap heuristic, but good enough. + elasticlunr.tokenizer.setSeperator(/[\s\-.;&_'"=,()]+|<[^>]*>/); + + let searchIndex; + if (docs._isPrebuiltIndex) { + console.info("using precompiled search index"); + searchIndex = elasticlunr.Index.load(docs); + } else { + console.time("building search index"); + // mirrored in build-search-index.js (part 2) + searchIndex = elasticlunr(function () { + this.pipeline.remove(elasticlunr.stemmer); + this.pipeline.remove(elasticlunr.stopWordFilter); + this.addField("qualname"); + this.addField("fullname"); + this.addField("annotation"); + this.addField("default_value"); + this.addField("signature"); + this.addField("bases"); + this.addField("doc"); + this.setRef("fullname"); + }); + for (let doc of docs) { + searchIndex.addDoc(doc); + } + console.timeEnd("building search index"); + } + + return (term) => searchIndex.search(term, { + fields: { + qualname: {boost: 4}, + fullname: {boost: 2}, + annotation: {boost: 2}, + default_value: {boost: 2}, + signature: {boost: 2}, + bases: {boost: 2}, + doc: {boost: 1}, + }, + expand: true + }); +})(); \ No newline at end of file diff --git a/linghe/__init__.py b/linghe/__init__.py index e69de29..8b13789 100644 --- a/linghe/__init__.py +++ b/linghe/__init__.py @@ -0,0 +1 @@ + diff --git a/linghe/facade/add.py b/linghe/facade/add.py index 945ad0e..c40aca4 100644 --- a/linghe/facade/add.py +++ b/linghe/facade/add.py @@ -9,10 +9,25 @@ class InplaceAddFunction(torch.autograd.Function): + """ + + """ @staticmethod - def forward(ctx, x, y): + def forward(ctx, x: torch.Tensor, y: torch.Tensor): return triton_inplace_add(x, y) @staticmethod def backward(ctx, grad_output): return grad_output, grad_output + + +def inplace_add(x: torch.Tensor, y: torch.Tensor): + """ + inplace add y to x with mix precise + Args: + x: to be updated + y: add to x + Returns: + return updated x tensor + """ + return InplaceAddFunction.apply(x, y) \ No newline at end of file diff --git a/linghe/facade/fp32_linear.py b/linghe/facade/fp32_gemm.py similarity index 66% rename from linghe/facade/fp32_linear.py rename to linghe/facade/fp32_gemm.py index d356856..8ab9ae9 100644 --- a/linghe/facade/fp32_linear.py +++ b/linghe/facade/fp32_gemm.py @@ -10,9 +10,12 @@ triton_fp32_gemm_for_update) -class FusedFp32GEMM(torch.autograd.Function): +class Fp32GEMM(torch.autograd.Function): + """ + + """ @staticmethod - def forward(ctx, input, weight): + def forward(ctx, input: torch.Tensor, weight: torch.Tensor): shape = input.shape assert len(shape) == 3 input = input.view(shape[0] * shape[1], shape[2]) @@ -32,9 +35,22 @@ def backward(ctx, grad_output): grad_output = grad_output.view(shape[0] * shape[1], shape[2]) input, weight = ctx.saved_tensors - dx = triton_fp32_gemm_for_backward(grad_output, weight, accum=False) + dx = triton_fp32_gemm_for_backward(grad_output, weight) dx = 
dx.view(*ctx.shape) dw = triton_fp32_gemm_for_update(grad_output, input) return dx, dw + + +def fp32_gemm(input: torch.Tensor, weight: torch.Tensor): + """ + gemm with bf16/fp16 inputs and float32 output, + currently used in MoE router gemm. + Args: + input: bf16/fp16 activation tensor + weight: bf16/fp16 weight tensor + Returns: + output of gemm + """ + return Fp32GEMM.apply(input, weight) \ No newline at end of file diff --git a/linghe/facade/hadamard_quant_linear.py b/linghe/facade/hadamard_quant_linear.py new file mode 100644 index 0000000..586f104 --- /dev/null +++ b/linghe/facade/hadamard_quant_linear.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import math +from typing import Optional + +import torch + +from linghe.quant.hadamard import triton_hadamard_quant + + + +class _HadamardQuantLinear(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + hadamard_matrix: torch.Tensor + ): + ctx.input_requires_grad = input.requires_grad + ctx.weight_requires_grad = weight.requires_grad + ctx.bias_requires_grad = bias is not None and bias.requires_grad + + ctx.out_dtype = input.dtype + ctx.input_shape = input.shape + input = input.view(-1, input.shape[-1]) + + x_q, x_scale, xt_q, xt_scale = triton_hadamard_quant(input, hadamard_matrix) + w_q, w_scale, wt_q, wt_scale = triton_hadamard_quant(weight, hadamard_matrix) + + output = torch._scaled_mm(x_q, + w_q.t(), + scale_a=x_scale, + scale_b=w_scale, + out_dtype=ctx.out_dtype, + use_fast_accum=True + ) + + if bias is not None: + output += bias + + saved_tensors = [ + xt_q if ctx.weight_requires_grad else None, + xt_scale if ctx.weight_requires_grad else None, + wt_q if ctx.input_requires_grad else None, + wt_scale if ctx.input_requires_grad else None, + hadamard_matrix if ctx.weight_requires_grad or ctx.weight_requires_grad else None + ] + + ctx.save_for_backward(*saved_tensors) + out_shape = (*ctx.input_shape[0:-1], -1) + return output.view(out_shape) + + @staticmethod + def backward( + ctx, + output_grad: torch.Tensor, + ): + xt_q, xt_scale, wt_q, wt_scale, hadamard_matrix = ctx.saved_tensors + results = [None, None, None, None] + + output_grad = output_grad.view(-1, output_grad.shape[-1]) + + y_q, y_scale, yt_q, yt_scale = triton_hadamard_quant(output_grad, hadamard_matrix) + + dx = torch._scaled_mm(y_q, + wt_q.t(), + scale_a=y_scale, + scale_b=wt_scale, + out_dtype=ctx.out_dtype, + use_fast_accum=True + ) + + # calculate input grad and assign to results[0] + results[0] = dx.view(ctx.input_shape) + + # calculate weight grad and assign to results[1] + dw = torch._scaled_mm(yt_q, + xt_q.t(), + scale_a=yt_scale, + scale_b=xt_scale, + out_dtype=ctx.out_dtype, + use_fast_accum=True + ) + results[1] = dw + + if ctx.bias_requires_grad: + # calculate bias grad and assign to results[2] + results[2] = torch.sum(output_grad, dim=0) + + return tuple(results) + +class HadamardQuantLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None + ): + """ + a naive implementation of hadamard transformation and quantization + Args: + in_features: in feature number + out_features: out feature number + bias: whether use bias + device: weight device + dtype: weight dtype + impl: implementation of hadamard quantization + """ + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight 
= torch.nn.parameter.Parameter( + torch.empty((out_features, in_features), device=device, + dtype=dtype)) + if bias: + self.bias = torch.nn.parameter.Parameter( + torch.empty(out_features, device=device, dtype=dtype)) + else: + self.bias = None + + size = 32 if 'H20' in torch.cuda.get_device_properties(0).name else 64 + data = self._hadamard_matrix(size, device=device, dtype=dtype, + norm=True) + self.hadamard_matrix = torch.nn.parameter.Parameter(data, + requires_grad=False) + self.reset_parameters() + + def _hadamard_matrix(self, size, device=None, dtype=None, norm=False): + assert 2 ** int(math.log2(size)) == size + m2 = torch.tensor([[1, 1], [1, -1]], device=device, dtype=torch.float32) + m = m2 + for _ in range(int(math.log2(size)) - 1): + m = torch.kron(m, m2) + if norm: + m = m / size ** 0.5 + if dtype is not None: + m = m.to(dtype) + return m + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.training: + return _HadamardQuantLinear.apply(input, self.weight, self.bias, + self.hadamard_matrix) + else: + output = input @ self.weight.t() + if self.bias is not None: + output = output + self.bias + return output + + def extra_repr(self) -> str: + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + def reset_parameters(self): + self.weight.data.normal_(mean=0.0, std=0.02) + if self.bias is not None: + self.bias.data.zero_() diff --git a/linghe/facade/loss.py b/linghe/facade/loss.py index fff59dd..1fa7294 100644 --- a/linghe/facade/loss.py +++ b/linghe/facade/loss.py @@ -10,6 +10,9 @@ class SoftmaxCrossEntropyFunction(torch.autograd.Function): + """ + + """ @staticmethod def forward(ctx, logits, labels, inplace=False): shape = logits.shape @@ -38,7 +41,26 @@ def backward(ctx, grad_output): return grad, None, None, None +def softmax_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, inplace: bool = False): + """ + softmax cross entropy + Args: + logits: logits tensor, shape [...,dim] + labels: labels tensor, shape [...] 
+ inplace: update gradient in the `logits` tensor if True + + Returns: + per token loss + """ + assert logits.is_contiguous() + assert labels.is_contiguous() + return SoftmaxCrossEntropyFunction.apply(logits, labels, inplace) + + class GradScalingFunction(torch.autograd.Function): + """ + + """ @staticmethod def forward(ctx, x, coef=0.2): ctx.coef = coef diff --git a/linghe/facade/norm.py b/linghe/facade/norm.py index c5d31ad..435942f 100644 --- a/linghe/facade/norm.py +++ b/linghe/facade/norm.py @@ -10,6 +10,9 @@ class RMSNormFunction(torch.autograd.Function): + """ + + """ @staticmethod def forward(ctx, x, weight, eps=1e-6): output = triton_rms_norm_forward( @@ -37,17 +40,35 @@ def backward(ctx, dy): return dx, dw, None +def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6): + """ + rms norm of x with weight + Args: + x: activation tensor + weight: weight tensor + eps: epsilon for RMS + + Returns: + rms output + """ + assert x.contiguous() + assert weight.contiguous() + return RMSNormFunction.apply(x, weight, eps) + class GroupNormGateFunction(torch.autograd.Function): + """ + + """ @staticmethod - def forward(ctx, x, gate, weight, eps=1e-6, group_size=4): + def forward(ctx, attn_output, gate, weight, eps=1e-6, group_size=4): output = triton_group_norm_gate_forward( - x, + attn_output, gate, weight.data, eps=eps, group_size=group_size ) - ctx.save_for_backward(x, gate, weight.data) + ctx.save_for_backward(attn_output, gate, weight.data) ctx.eps = eps ctx.group_size = group_size @@ -55,11 +76,11 @@ def forward(ctx, x, gate, weight, eps=1e-6, group_size=4): @staticmethod def backward(ctx, dy): - x, gate, weight = ctx.saved_tensors + attn_output, gate, weight = ctx.saved_tensors dx, dg, dw = triton_group_norm_gate_backward( dy, - x, + attn_output, gate, weight, ctx.eps, @@ -67,3 +88,23 @@ def backward(ctx, dy): ) return dx, dg, dw, None, None + + + +def group_norm_gate(attn_output: torch.Tensor, + gate: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + group_size: int = 4): + """ + return group_rms_norm(transpose(attn_output, [0,1]), weight) * sigmoid(gate) + Args: + attn_output: output of core attn, shape [bs, length, n_heads, head_dim] + gate: gate tensor for attention output, shape [length, bs, dim] + weight: weight of RMS norm, shape [dim] + eps: epsilon for RMS + group_size: group size of group RMS norm + Returns: + output with shape [length, bs, dim] + """ + return GroupNormGateFunction.apply(attn_output, gate, weight, eps, group_size) \ No newline at end of file diff --git a/linghe/facade/rope.py b/linghe/facade/rope.py index f12d518..3719095 100644 --- a/linghe/facade/rope.py +++ b/linghe/facade/rope.py @@ -10,6 +10,9 @@ class QkNormHalfRopeFunction(torch.autograd.Function): + """ + + """ @staticmethod def forward(ctx, qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4, eps=1e-6): @@ -47,3 +50,35 @@ def backward(ctx, grad_q, grad_k, grad_v): transpose=True, interleave=True) return dqkv, dqw, dkw, None, None, None, None + + +def qk_norm_half_rope(qkv: torch.Tensor, + q_norm_weight: torch.Tensor, + k_norm_weight: torch.Tensor, + freqs: torch.Tensor, + H: int = 32, + h: int = 4, + eps: float = 1e-6): + """ + split qkv to q/k/v, apply qk norm and half rope to q/k, transpose q/k/v to flash-attention layout + Args: + qkv: QKV tensor with size of [S, B, dim], heads are interleaved + q_norm_weight: rms norm weight for query + k_norm_weight: rms norm weight for key + freqs: Freqs tensor based on half dim. + H: Number of attention heads. 
+ h: Number of key/value heads. + eps: epsilon value for L2 normalization. + + Returns: + qo: shape [B, S, H, head_dim] + ko: shape [B, S, h, head_dim] + vo: shape [B, S, h, head_dim] + """ + return QkNormHalfRopeFunction.apply(qkv, + q_norm_weight, + k_norm_weight, + freqs, + H, + h, + eps) \ No newline at end of file diff --git a/linghe/facade/smooth_quant_linear.py b/linghe/facade/smooth_quant_linear.py new file mode 100644 index 0000000..fccbabb --- /dev/null +++ b/linghe/facade/smooth_quant_linear.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +from typing import Optional + +import torch + + +from linghe.quant.smooth import triton_smooth_quant, \ + triton_transpose_smooth_quant +from linghe.utils.transpose import triton_transpose_and_pad +from linghe.utils.reduce import triton_abs_max + +class _SmoothQuantLinear(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + smooth_scale: torch.Tensor, + bias: Optional[torch.Tensor] + ): + ctx.input_requires_grad = input.requires_grad + ctx.weight_requires_grad = weight.requires_grad + ctx.bias_requires_grad = bias is not None and bias.requires_grad + + ctx.out_dtype = input.dtype + ctx.input_shape = input.shape + input = input.view(-1, input.shape[-1]) + + x_q, x_scale, x_maxs = triton_smooth_quant(input, 1 / smooth_scale) + w_q, w_scale, w_maxs = triton_smooth_quant(weight, smooth_scale) + + output = torch._scaled_mm(x_q, + w_q.t(), + scale_a=x_scale.view(-1, 1), + scale_b=w_scale.view(1, -1), + out_dtype=ctx.out_dtype, + use_fast_accum=True) + + if bias is not None: + output += bias + + saved_tensors = [ + x_q if ctx.weight_requires_grad else None, + x_scale if ctx.weight_requires_grad else None, + w_q if ctx.input_requires_grad else None, + w_scale if ctx.input_requires_grad else None, + smooth_scale if ctx.weight_requires_grad or ctx.weight_requires_grad else None + ] + + ctx.save_for_backward(*saved_tensors) + out_shape = (*ctx.input_shape[0:-1], -1) + return output.view(out_shape) + + @staticmethod + def backward( + ctx, + output_grad: torch.Tensor + ): + x_q, x_s, w_q, w_s, smooth_scale = ctx.saved_tensors + results = [None, None, None, None] + + output_grad = output_grad.view(-1, output_grad.shape[-1]) + + y_q, y_scale, y_maxs = triton_smooth_quant(output_grad, w_s) + + wt_q = triton_transpose_and_pad(w_q, pad=True) + dx = torch._scaled_mm(y_q, + wt_q.t(), + scale_a=y_scale.view(-1, 1), + scale_b=smooth_scale.view(1, -1), + out_dtype=ctx.out_dtype, + use_fast_accum=True) + + # calculate input grad and assign to results[0] + results[0] = dx.view(ctx.input_shape) + + # calculate weight grad and assign to results[1] + yt_q, yt_scale, yt_maxs = triton_transpose_smooth_quant(output_grad, x_s) + + xt_q = triton_transpose_and_pad(x_q, pad=True) + dw = torch._scaled_mm(yt_q, + xt_q.t(), + scale_a=yt_scale.view(-1, 1), + scale_b=1/smooth_scale.view(1, -1), + out_dtype=ctx.out_dtype, + use_fast_accum=True) + + results[1] = dw + + if ctx.bias_requires_grad: + # calculate bias grad and assign to results[2] + results[2] = torch.sum(output_grad, dim=0) + + return tuple(results) + + +class QuantLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.parameter.Parameter( + torch.empty((out_features, in_features), 
device=device, + dtype=dtype)) + if bias: + self.bias = torch.nn.parameter.Parameter( + torch.empty(out_features, device=device, dtype=dtype)) + else: + self.bias = None + + self.gap_step = 16 + self.decay_coef = 0.9 + self.smooth_scale = None + self.smooth_update_step = 0 + + self.reset_parameters() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.training: + + if self.smooth_update_step % self.gap_step == 0: + input_maxs = triton_abs_max(input) + weight_maxs = triton_abs_max(self.weight.data) + self.smooth_scale = torch.sqrt(input_maxs * weight_maxs) + + output, smooth_scale = _SmoothQuantLinear.apply(input, + self.weight, + self.bias, + self.smooth_scale) + self.smooth_update_step += 1 + else: + output = input @ self.weight.t() + if self.bias is not None: + output = output + self.bias + return output + + def extra_repr(self) -> str: + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + def reset_parameters(self): + self.weight.data.normal_(mean=0.0, std=0.02) + if self.bias is not None: + self.bias.data.zero_() diff --git a/linghe/facade/transpose.py b/linghe/facade/transpose.py index d3ecfaa..9d8de83 100644 --- a/linghe/facade/transpose.py +++ b/linghe/facade/transpose.py @@ -9,6 +9,9 @@ class TransposeDim01Function(torch.autograd.Function): + """ + + """ @staticmethod def forward(ctx, x): return triton_transpose(x, dim0=0, dim1=1) @@ -16,3 +19,15 @@ def forward(ctx, x): @staticmethod def backward(ctx, grad_output): return triton_transpose(grad_output, dim0=0, dim1=1) + + +def transpose_dim01(x): + """ + transpose a tensor with the first two dims, x.ndims should not greater than 4 + Args: + x: input tensor + + Returns: + a transposed tensor + """ + return TransposeDim01Function.apply(x) \ No newline at end of file diff --git a/linghe/gemm/blockwise_fp8_gemm.py b/linghe/gemm/blockwise_fp8_gemm.py new file mode 100644 index 0000000..da9416a --- /dev/null +++ b/linghe/gemm/blockwise_fp8_gemm.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import torch +import triton +import triton.language as tl +from triton import Config + +fp8_gemm_configs = [ + Config({"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, + num_stages=num_stages, num_warps=8) + for block_m in [32, 64, 128] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] +] + + +# @triton.autotune(configs=fp8_gemm_configs, key=["N", "K"]) +@triton.jit +def fp8_gemm_bb_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + # a blockwise quantization, b blockwise quantization. 
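+    # Computes c = a @ b.T with a: [M, K] fp8 and b: [N, K] fp8.
+    # a_s holds one scale per (BLOCK_SIZE_M, BLOCK_SIZE_K) block of a and
+    # b_s one scale per (BLOCK_SIZE_N, BLOCK_SIZE_K) block of b; each K-step
+    # partial product is dequantized by a_s * b_s and accumulated in fp32.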
+ pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + # b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + b_ptrs = b_ptr + offs_n[:, None] * K + offs_k[None, :] + nb = K // BLOCK_SIZE_K + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(0, k): + a_s = tl.load(a_s_ptr + pid_m * nb + i) + b_s = tl.load(b_s_ptr + pid_n * nb + i) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, + other=0.0) + accumulator += tl.dot(a, tl.trans(b)) * (a_s * b_s) + # accumulator = tl.dot(a, tl.trans(b), accumulator) + # accumulator += (accumulators-accumulator) * scale + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + + +def triton_bb_fp8_gemm(a: torch.Tensor, + b: torch.Tensor, + a_s: torch.Tensor, + b_s: torch.Tensor, + out_dtype=torch.bfloat16, + block_size=128): + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = torch.empty(M, N, dtype=out_dtype, device=a.device) + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"])) # noqa + + fp8_gemm_bb_kernel[grid](a, b, c, a_s, b_s, + M, N, K, + BLOCK_SIZE_K=block_size, + BLOCK_SIZE_M=block_size, + BLOCK_SIZE_N=block_size, + num_warps=8, + num_stages=4 + ) + return c + + + +@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"]) +@triton.jit +def fp8_gemm_tb_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + # a tilewise quantization, b blockwise quantization. 
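+    # Computes c = a @ b.T where a: [M, K] fp8 carries one scale per
+    # (1, BLOCK_SIZE_K) tile (per row per K-block) and b: [N, K] fp8 carries
+    # one scale per (BLOCK_SIZE_K, BLOCK_SIZE_K) block.
+    # Below, `accumulators - accumulator` isolates the current K-step's dot
+    # product so it can be scaled by a_s[:, None] * b_s[None, :] before being
+    # added to the fp32 accumulator.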
+ pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + # a = tl.load(a_ptrs) + # b = tl.load(b_ptrs) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, + other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + # accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + accumulators = tl.dot(a, b, accumulator) + accumulator += (accumulators - accumulator) * a_s[:, None] * b_s[None, + :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + # tl.store(c_ptrs, c) + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + + + +def triton_tb_fp8_gemm(a: torch.Tensor, + b: torch.Tensor, + a_s: torch.Tensor, + b_s: torch.Tensor, + out_dtype=torch.bfloat16, + block_size=128): + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = torch.empty(M, N, dtype=out_dtype, device=a.device) + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"])) # noqa + + fp8_gemm_tb_kernel[grid](a, b, c, + a_s, b_s, + M, N, K, + block_size + ) + return c + + +@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"]) +@triton.jit +def fp8_gemm_tt_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + # a and b all tilewise quantization. 
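+    # Scale layout: both operands are tilewise quantized, i.e. one fp32 scale per row
+    # and per BLOCK_SIZE_K slice along K (a_s as (M, K // BLOCK_SIZE_K), b_s as
+    # (N, K // BLOCK_SIZE_K)), so every K-step scales its tl.dot(a, b) tile by
+    # a_s[:, None] * b_s[None, :] before accumulation.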
+ pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + offs_n * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, + other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def triton_tt_fp8_gemm(a: torch.Tensor, + b: torch.Tensor, + a_s: torch.Tensor, + b_s: torch.Tensor, + out_dtype=torch.bfloat16, + block_size=128): + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = torch.empty(*a.size()[:-1], N, dtype=out_dtype, device=a.device) + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"])) # noqa + fp8_gemm_tt_kernel[grid](a, b, c, + a_s, b_s, + M, N, K, + block_size) + return c diff --git a/linghe/gemm/channelwise_fp8_gemm.py b/linghe/gemm/channelwise_fp8_gemm.py new file mode 100644 index 0000000..a4d978f --- /dev/null +++ b/linghe/gemm/channelwise_fp8_gemm.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. 
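+
+Channelwise scaled FP8 GEMM: similar in spirit to torch._scaled_mm, with one fp32
+scale per row of a and per row of b (i.e. per output channel), and optional
+accumulation of the product into an existing output tensor (see triton_scaled_mm).
+
+Rough usage sketch (illustrative only: a_q, b_q, a_scale, b_scale are assumed to be
+fp8 operands and per-row dequantization scales produced elsewhere, e.g. by
+linghe.quant.channel.triton_row_quant, with M and N multiples of the launch block
+sizes):
+
+    out = triton_scaled_mm(a_q, b_q, a_scale, b_scale,
+                           out_dtype=torch.bfloat16, accum=False)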
+""" + +import torch +import triton +import triton.language as tl + + +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" + + +# fp8_gemm_configs = [ +# Config({"BLOCK_SIZE_K": block_k, "BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, num_stages=num_stages, num_warps=num_warps) +# for block_k in [64, 128, 256] +# for block_m in [64, 128, 256] +# for block_n in [64, 128, 256] +# for num_stages in [2, 3, 4, 5] +# for num_warps in [4, 8, 16] +# # for num_stages in [3] +# # for num_warps in [8] +# ] + +# @triton.autotune(configs=fp8_gemm_configs, key=["M", "N", "K"]) +@triton.jit +def scaled_mm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + N, + K, + ACCUM: tl.constexpr, + EVEN: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_scale = tl.load(a_scale_ptr + offs_m) + b_scale = tl.load(b_scale_ptr + offs_n) + + if ACCUM: + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + accumulator = tl.load(c_ptrs).to(tl.float32) + a_s = 1 / tl.maximum(a_scale, 1e-30) + b_s = 1 / tl.maximum(b_scale, 1e-30) + accumulator = accumulator * a_s[:, None] * b_s[None, :] + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if EVEN: + for i in range(k): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + else: + for i in range(k): + indices = i * BLOCK_SIZE_K + offs_k + a = tl.load(a_ptrs, mask=indices[None, :] < K) + b = tl.load(b_ptrs, mask=indices[:, None] < K) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + accumulator = accumulator * a_scale[:, None] * b_scale[None, :] + accumulator = accumulator.to(c_ptr.dtype.element_ty) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + tl.store(c_ptrs, accumulator) + + +def triton_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + out_dtype=torch.float32, + c=None, + accum=True): + """ + similar to torch._scaled_mm, support accumulating gemm output to c + and low precision output tensor + Args: + a: left fp8 tensor + b: right fp8 tensor, column-major + a_scale: fp32 scale of a + b_scale: fp32 scale of b + out_dtype: output tensor dtype + c: output tensor + accum: accumulate output on c if True + + Returns: + c: output tensor + """ + assert a.is_contiguous() and b.is_contiguous() + M, K = a.size() + N, K = b.size() + ACCUM = accum and c is not None + if c is None: + c = torch.empty(M, N, dtype=out_dtype, device=a.device) + BLOCK_SIZE_K = 128 + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + EVEN = K % BLOCK_SIZE_K == 0 + grid = lambda META: ( + M // META["BLOCK_SIZE_M"], N // META["BLOCK_SIZE_N"]) # noqa + scaled_mm_kernel[grid](a, b, c, + a_scale, + b_scale, + N, K, + ACCUM, + EVEN, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_stages=3, + num_warps=8 + ) + + return c diff --git a/linghe/gemm/fp32_gemm.py b/linghe/gemm/fp32_gemm.py index e4bbeb9..8f44067 100644 --- a/linghe/gemm/fp32_gemm.py +++ b/linghe/gemm/fp32_gemm.py @@ -59,10 +59,19 @@ def fp32_gemm_kernel( tl.store(c_ptrs, c) -# a, bf16 -# b, bf16 
-# c, fp32 + def triton_fp32_gemm(a: torch.Tensor, b: torch.Tensor): + """ + return fp32 gemm result with fp16/bf16 inputs, + it's mainly used for MoE router GEMM + and DO NOT suitable for large size GEMM + Args: + a: left matrix with fp16/bf16 precision + b: right matrix with fp16/bf16 precision + + Returns: + c: output with fp32 precision + """ assert a.is_contiguous() and b.is_contiguous() M, K = a.size() N, K = b.size() @@ -86,11 +95,11 @@ def triton_fp32_gemm(a: torch.Tensor, b: torch.Tensor): return c +# @triton.autotune(configs=fp32_gemm_configs, key=["M", "N", "K"]) @triton.jit -def scaled_fp32_gemm_kernel( +def fp32_gemm_for_backward_kernel( a_ptr, b_ptr, - scale_ptr, c_ptr, M, N: tl.constexpr, @@ -106,61 +115,64 @@ def scaled_fp32_gemm_kernel( offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] - b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): - a = tl.load(a_ptrs).to(tl.float32) + a = tl.load(a_ptrs) b = tl.load(b_ptrs).to(tl.float32) # c += tl.dot(a, b) c = tl.dot(a, b, c) a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - - scale = tl.load( - scale_ptr + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - c *= scale[:, None] - + b_ptrs += BLOCK_SIZE_K * N offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] tl.store(c_ptrs, c) -def triton_scaled_fp32_gemm(a: torch.Tensor, b: torch.Tensor, - scale: torch.Tensor): +def triton_fp32_gemm_for_backward(a: torch.Tensor, + b: torch.Tensor): + """ + mix precision gemm for backward, a@b.float() + Args: + a: input gradient, fp32 + b: gemm weight, bf16/fp16 + Returns: + c: gradient of activation + """ assert a.is_contiguous() and b.is_contiguous() M, K = a.size() - N, K = b.size() - c = torch.empty(M, N, dtype=torch.float32, device=a.device) + K, N = b.size() + c = torch.empty((M, N), dtype=b.dtype, device=b.device) grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"])) # noqa BLOCK_SIZE_K = 128 BLOCK_SIZE_M = 32 BLOCK_SIZE_N = 128 num_warps = 4 - num_stages = 3 - scaled_fp32_gemm_kernel[grid](a, b, scale, c, - M, N, K, - BLOCK_SIZE_K, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages - ) + num_stages = 2 + fp32_gemm_for_backward_kernel[grid](a, b, c, + M, N, K, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_warps=num_warps, + num_stages=num_stages + ) return c # @triton.autotune(configs=fp32_gemm_configs, key=["M", "N", "K"]) @triton.jit -def fp32_gemm_for_backward_kernel( +def fp32_gemm_for_update_kernel( a_ptr, b_ptr, c_ptr, M, N: tl.constexpr, K: tl.constexpr, - ACCUM: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -171,63 +183,62 @@ def fp32_gemm_for_backward_kernel( offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + a_ptrs = a_ptr + offs_m[None, :] + offs_k[:, None] * M b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N - if ACCUM: - c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to( - tl.float32) - else: - c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32) - + c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(tl.float32) for i in range(k): - a = tl.load(a_ptrs) + a = tl.trans(tl.load(a_ptrs)).to(tl.float32) b = tl.load(b_ptrs).to(tl.float32) # c += tl.dot(a, b) c = tl.dot(a, b, c) - a_ptrs += BLOCK_SIZE_K + a_ptrs += BLOCK_SIZE_K * M b_ptrs += BLOCK_SIZE_K * N + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] tl.store(c_ptrs, c) -# a: router output, fp32 -# b: router weight, bf16, should be transposed before calculation -# c: dy of rms, bf16, shoule be accumlated -def triton_fp32_gemm_for_backward(a: torch.Tensor, b: torch.Tensor, - c: Optional[torch.Tensor] = None, - accum=False): +def triton_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor): + """ + mix precision gemm for updaing weight + Args: + a: gradient of output, fp32 + b: input activation, bf16/fp16 + Returns: + c: gradient of weight + """ assert a.is_contiguous() and b.is_contiguous() - M, K = a.size() + K, M = a.size() K, N = b.size() - if c is None: - c = torch.empty((M, N), dtype=b.dtype, device=b.device) - accum = False + c = torch.empty((M, N), dtype=b.dtype, device=b.device) grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"])) # noqa BLOCK_SIZE_K = 128 BLOCK_SIZE_M = 32 BLOCK_SIZE_N = 128 num_warps = 4 - num_stages = 2 - fp32_gemm_for_backward_kernel[grid](a, b, c, - M, N, K, accum, - BLOCK_SIZE_K, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages - ) + num_stages = 3 + fp32_gemm_for_update_kernel[grid](a, b, c, + M, N, K, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_warps=num_warps, + num_stages=num_stages + ) return c -# @triton.autotune(configs=fp32_gemm_configs, key=["M", "N", "K"]) + @triton.jit -def fp32_gemm_for_update_kernel( +def scaled_fp32_gemm_kernel( a_ptr, b_ptr, + scale_ptr, c_ptr, M, N: tl.constexpr, @@ -242,18 +253,21 @@ def fp32_gemm_for_update_kernel( offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + offs_m[None, :] + offs_k[:, None] * M - b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(tl.float32) for i in range(k): - a = tl.trans(tl.load(a_ptrs)).to(tl.float32) + a = tl.load(a_ptrs).to(tl.float32) b = tl.load(b_ptrs).to(tl.float32) # c += tl.dot(a, b) c = tl.dot(a, b, c) - a_ptrs += BLOCK_SIZE_K * M - b_ptrs += BLOCK_SIZE_K * N + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + scale = tl.load( + scale_ptr + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + c *= scale[:, None] offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -261,13 +275,34 @@ def fp32_gemm_for_update_kernel( tl.store(c_ptrs, c) -# a: router output, fp32, should be transposed before calculation -# b: input of rms, bf16, should be transposed before calculation -def triton_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor): +def triton_scaled_fp32_gemm(a: torch.Tensor, + b: torch.Tensor, + scale: torch.Tensor): + """ + c = 
(a*scale[:,None])*b + this kernel is used to fuse RMSNorm and quantization in MoE layer + native implementation: + y = rms_norm(x), + y_q = quantization(y), + router_logits = y@w + we can not fuse rms_norm and quantization + as we still need bf16 y for moe router gemm + fused implementation: + y_q, rms = quantization(rms_norm(x)) + router_logits = (x/rms)@y + so we need a scaled fp32 gemm kernel + Args: + a: activation tensor + b: weight tensor + scale: scale for activation tensor, 1/rms + + Returns: + + """ assert a.is_contiguous() and b.is_contiguous() - K, M = a.size() - K, N = b.size() - c = torch.empty((M, N), dtype=b.dtype, device=b.device) + M, K = a.size() + N, K = b.size() + c = torch.empty(M, N, dtype=torch.float32, device=a.device) grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"])) # noqa BLOCK_SIZE_K = 128 @@ -275,17 +310,20 @@ def triton_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor): BLOCK_SIZE_N = 128 num_warps = 4 num_stages = 3 - fp32_gemm_for_update_kernel[grid](a, b, c, - M, N, K, - BLOCK_SIZE_K, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages - ) + scaled_fp32_gemm_kernel[grid](a, b, + scale, + c, + M, N, K, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_warps=num_warps, + num_stages=num_stages + ) return c + @triton.jit def scaled_fp32_gemm_for_update_kernel( a_ptr, @@ -309,7 +347,6 @@ def scaled_fp32_gemm_for_update_kernel( b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(tl.float32) for i in range(k): scale = tl.load( scale_ptr + i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) @@ -326,11 +363,19 @@ def scaled_fp32_gemm_for_update_kernel( tl.store(c_ptrs, c) -# a: router output, fp32, should be transposed before calculation -# b: input of rms, bf16, should be transposed before calculation -# scale: 1/rms -def triton_scaled_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor, +def triton_scaled_fp32_gemm_for_update(a: torch.Tensor, + b: torch.Tensor, scale: torch.Tensor): + """ + see triton_scaled_fp32_gemm + Args: + a: y + b: activation before RMS norm + scale: 1/rms + + Returns: + dw + """ assert a.is_contiguous() and b.is_contiguous() K, M = a.size() K, N = b.size() diff --git a/linghe/quant/block/block.py b/linghe/quant/block.py similarity index 85% rename from linghe/quant/block/block.py rename to linghe/quant/block.py index cdb9d16..cfb37fa 100644 --- a/linghe/quant/block/block.py +++ b/linghe/quant/block.py @@ -28,9 +28,20 @@ def block_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr, tl.store(s_ptr + pid_m * n + pid_n, s) -def block_quant(x, +def triton_block_quant(x, block_size=128, round_scale=False): + """ + blockwise quantize x + Args: + x: input tensor + block_size: block wise + round_scale: whether round scale to power of 2 + + Returns: + y: quantized tensor, float8_e4m3fn + s: quantization scale, float32 + """ M, N = x.size() y = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device) s = x.new_empty(x.size(-2) // block_size, x.size(-1) // block_size, diff --git a/linghe/quant/block/__init__.py b/linghe/quant/block/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/linghe/quant/block/group.py b/linghe/quant/block/group.py deleted file mode 100644 index de5908c..0000000 --- a/linghe/quant/block/group.py +++ /dev/null @@ -1,107 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Copyright (c) Ant 
Financial Service Group and its affiliates. -""" - -import torch -import triton -import triton.language as tl - - -@triton.jit -def group_quant_kernel(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: tl.constexpr, - K: tl.constexpr, ROUND: tl.constexpr): - pid = tl.program_id(axis=0) - offs = pid * N + tl.arange(0, K * BLOCK_SIZE) - n = tl.cdiv(N, K * BLOCK_SIZE) - soffs = pid * n * K + tl.arange(0, K) - for i in range(n): - x = tl.load(x_ptr + offs).to(tl.float32) - x = tl.reshape(x, (K, BLOCK_SIZE), can_reorder=False) - s = tl.maximum(tl.max(tl.abs(x), 1) / 448.0, 1e-30) - if ROUND: - s = tl.exp2(tl.ceil(tl.log2(s))) - y = x / s[:, None] - y = y.to(y_ptr.dtype.element_ty) - y = tl.reshape(y, (K * BLOCK_SIZE,), can_reorder=False) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + soffs, s) - offs += K * BLOCK_SIZE - soffs += K - - -def triton_group_quant(x, dtype=torch.float8_e4m3fn, group_size=128, - round_scale=False): - M, N = x.shape - K = 16 - assert N % group_size == 0 and N % (group_size * K) == 0 - assert x.is_contiguous() - - y = torch.empty((M, N), device=x.device, dtype=dtype) - s = torch.empty(M, N // group_size, device=x.device, dtype=torch.float32) - grid = (M,) # noqa - group_quant_kernel[grid](x, - y, - s, - N, - group_size, - K, - round_scale, - num_stages=5, - num_warps=4) - return y, s - - -@triton.jit -def persist_group_quant_kernel(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: tl.constexpr, - B: tl.constexpr, K: tl.constexpr, - ROUND: tl.constexpr): - pid = tl.program_id(axis=0) - offs = pid * B * N + tl.arange(0, B)[:, None] * N + tl.arange(0, - K * BLOCK_SIZE)[ - None, :] - n = tl.cdiv(N, K * BLOCK_SIZE) - soffs = pid * B * n * K + tl.arange(0, B)[:, None] * n * K + tl.arange(0, - K)[ - None, :] - - for j in range(n): - x = tl.load(x_ptr + offs).to(tl.float32) - x = tl.reshape(x, (B, K, BLOCK_SIZE)) - - s = tl.maximum(tl.max(tl.abs(x), 2) / 448.0, 1e-30) - if ROUND: - s = tl.exp2(tl.ceil(tl.log2(s))) - y = x / s[:, :, None] - y = y.to(y_ptr.dtype.element_ty) - y = tl.reshape(y, (B, K * BLOCK_SIZE)) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + soffs, s) - offs += K * BLOCK_SIZE - soffs += K - - -def triton_persist_group_quant(x, dtype=torch.float8_e4m3fn, group_size=128, - round_scale=False): - M, N = x.shape - device = x.device - K = 8 - B = 8 - assert N % group_size == 0 and N % (group_size * K) == 0 - assert x.is_contiguous() - - y = torch.empty((M, N), dtype=dtype, device=device) - s = torch.empty(M, N // group_size, device=x.device, dtype=torch.float32) - - grid = (M // B,) # noqa - persist_group_quant_kernel[grid](x, - y, - s, - N, - group_size, - B, - K, - round_scale, - num_stages=3, - num_warps=8) - return y, s diff --git a/linghe/quant/channel/channel.py b/linghe/quant/channel.py similarity index 88% rename from linghe/quant/channel/channel.py rename to linghe/quant/channel.py index 4968747..01a9571 100644 --- a/linghe/quant/channel/channel.py +++ b/linghe/quant/channel.py @@ -3,6 +3,7 @@ Copyright (c) Ant Financial Service Group and its affiliates. 
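+Rowwise (per-token / per-channel) FP8 quantization kernels: every row of the input
+tensor gets its own fp32 scale, optionally rounded to a power of two.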
""" +from typing import Optional import torch import triton import triton.language as tl @@ -36,6 +37,16 @@ def row_quant_kernel(x_ptr, q_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr, def triton_row_quant(x, round_scale=False): + """ + rowwise quantize x + Args: + x: input x + round_scale: whether round scale to power of 2 + + Returns: + x_q: quantized tensor + x_scale: quantization scale + """ M, N = x.shape BLOCK_SIZE = max([N % x == 0 for x in [512, 1024, 2048, 4096, 8192]]) x_q = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device) @@ -73,9 +84,10 @@ def deprecated_tokenwise_row_quant_kernel(x_ptr, out_ptr, scale_ptr, M, offs += N -def triton_deprecated_tokenwise_row_quant(x, out=None, scale=None, - round_scale=False): - # row-wise read, row-wise write +def triton_deprecated_tokenwise_row_quant(x: torch.Tensor, + out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + round_scale: bool = False): M, N = x.shape device = x.device if out is None: @@ -113,6 +125,16 @@ def tokenwise_row_quant_kernel(x_ptr, out_ptr, scale_ptr, N: tl.constexpr, def triton_tokenwise_row_quant(x, out=None, scale=None, round_scale=False): + """ + rowwise quantize x with power of 2 dim size + Args: + x: input x + round_scale: whether round scale to power of 2 + + Returns: + out: quantized tensor + scale: quantization scale + """ # row-wise read, row-wise write M, N = x.shape device = x.device @@ -169,7 +191,18 @@ def transpose_row_quant_kernel(x_ptr, q_ptr, s_ptr, M, N, H: tl.constexpr, toffs += H -def triton_transpose_row_quant(x, side=0, round_scale=False): +def triton_transpose_row_quant(x, round_scale=False): + """ + transpose x and row quantize x + Args: + x: input x + round_scale: whether round scale to power of 2 + + Returns: + x_q: quantized tensor + x_scale: quantization scale + + """ M, N = x.shape H = 1024 W = 16 @@ -218,7 +251,6 @@ def channel_quant_forward(x, w): def channel_quant_backward(y, w): y_q, y_scale, w_q, w_scale = triton_channel_quant_nn(y, w) - # print(f'{y.shape=} {w.shape=} {y_q.shape=} {y_scale.shape=} {w_q.shape=} {w_scale.shape=}') output = torch._scaled_mm(y_q, w_q.t(), scale_a=y_scale, diff --git a/linghe/quant/channel/__init__.py b/linghe/quant/channel/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/linghe/quant/group.py b/linghe/quant/group.py new file mode 100644 index 0000000..9dec9b8 --- /dev/null +++ b/linghe/quant/group.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def group_quant_kernel(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: tl.constexpr, + K: tl.constexpr, ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + offs = pid * N + tl.arange(0, K * BLOCK_SIZE) + n = tl.cdiv(N, K * BLOCK_SIZE) + soffs = pid * n * K + tl.arange(0, K) + for i in range(n): + x = tl.load(x_ptr + offs).to(tl.float32) + x = tl.reshape(x, (K, BLOCK_SIZE), can_reorder=False) + s = tl.maximum(tl.max(tl.abs(x), 1) / 448.0, 1e-30) + if ROUND: + s = tl.exp2(tl.ceil(tl.log2(s))) + y = x / s[:, None] + y = y.to(y_ptr.dtype.element_ty) + y = tl.reshape(y, (K * BLOCK_SIZE,), can_reorder=False) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + soffs, s) + offs += K * BLOCK_SIZE + soffs += K + + +def triton_group_quant(x, + dtype=torch.float8_e4m3fn, + group_size=128, + round_scale=False): + """ + groupwise quantize x, group is in under rowwise format + Args: + x: input tensor + group_size: group wise + round_scale: whether round scale to power of 2 + + Returns: + y: quantized tensor, float8_e4m3fn + s: quantization scale, float32 + """ + M, N = x.shape + K = 16 + assert N % group_size == 0 and N % (group_size * K) == 0 + assert x.is_contiguous() + + y = torch.empty((M, N), device=x.device, dtype=dtype) + s = torch.empty(M, N // group_size, device=x.device, dtype=torch.float32) + grid = (M,) # noqa + group_quant_kernel[grid](x, + y, + s, + N, + group_size, + K, + round_scale, + num_stages=5, + num_warps=4) + return y, s + diff --git a/linghe/quant/hadamard.py b/linghe/quant/hadamard.py new file mode 100644 index 0000000..1bd77bc --- /dev/null +++ b/linghe/quant/hadamard.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def hadamard_quant_row_kernel( + x_ptr, + hm_ptr, + x_q_ptr, + x_scale_ptr, + M, + N, + BLOCK_SIZE: tl.constexpr, + R: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * R * BLOCK_SIZE + rows = row_start + tl.arange(0, R * BLOCK_SIZE) + mask_rows = rows < M + + hm = tl.load( + hm_ptr + tl.arange(0, BLOCK_SIZE)[:, None] * BLOCK_SIZE + tl.arange(0, + BLOCK_SIZE)[ + None, :]) + + max_val = tl.zeros((R * BLOCK_SIZE,), dtype=tl.float32) + 1.17e-38 + + num_col_blocks = tl.cdiv(N, BLOCK_SIZE) + for col_block in range(num_col_blocks): + col_start = col_block * BLOCK_SIZE + cols = col_start + tl.arange(0, BLOCK_SIZE) + mask_cols = cols < N + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(x, hm) + current_max = tl.max(tl.abs(x_transformed), axis=1) + max_val = tl.maximum(max_val, current_max) + + scale = max_val / 448.0 + tl.store(x_scale_ptr + rows, scale, mask=mask_rows) + s = 448.0 / tl.where(max_val > 0, max_val, 1.0) + + for col_block in range(num_col_blocks): + col_start = col_block * BLOCK_SIZE + cols = col_start + tl.arange(0, BLOCK_SIZE) + mask_cols = cols < N + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(x, hm) + quantized = (x_transformed * s[:, None]).to(x_q_ptr.dtype.element_ty) + tl.store(x_q_ptr + offs, quantized, + mask=mask_rows[:, None] & mask_cols[None, :]) + + +@triton.jit +def hadamard_quant_col_kernel( + x_ptr, + hm_ptr, + xt_q_ptr, + xt_scale_ptr, + M, + N, + BLOCK_SIZE: tl.constexpr, + R: tl.constexpr, +): + pid = tl.program_id(0) + col_start = pid * R * BLOCK_SIZE + cols = col_start + tl.arange(0, R * BLOCK_SIZE) + mask_cols = cols < N + + hm = tl.load( + hm_ptr + tl.arange(0, BLOCK_SIZE)[:, None] * BLOCK_SIZE + tl.arange(0, + BLOCK_SIZE)[ + None, :]) + + max_val = tl.zeros((R * BLOCK_SIZE,), dtype=tl.float32) + 1.17e-38 + + num_row_blocks = tl.cdiv(M, BLOCK_SIZE) + for row_block in range(num_row_blocks): + row_start = row_block * BLOCK_SIZE + rows = row_start + tl.arange(0, BLOCK_SIZE) + mask_rows = rows < M + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(hm, x) + current_max = tl.max(tl.abs(x_transformed), axis=0) + max_val = tl.maximum(max_val, current_max) + + scale = max_val / 448.0 + tl.store(xt_scale_ptr + cols, scale, mask=mask_cols) + s = 448.0 / tl.where(max_val > 0, max_val, 1.0) + + for row_block in range(num_row_blocks): + row_start = row_block * BLOCK_SIZE + rows = row_start + tl.arange(0, BLOCK_SIZE) + mask_rows = rows < M + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(hm, x) + quantized = (x_transformed * s[None, :]).to(xt_q_ptr.dtype.element_ty) + quantized_t = tl.trans(quantized) + store_offs = cols[:, None] * M + rows[None, :] + tl.store(xt_q_ptr + store_offs, quantized_t, + mask=mask_cols[:, None] & mask_rows[None, :]) + + +def triton_hadamard_quant(x, hm): + """ + apply hadamard transformation and then quantize transformed tensor + Args: + x: input tensor + hm: hamadard matrix + Returns: + x_q: rowwise quantized tensor of non-transposed x + x_scale: rowwise quantization scale of non-transposed x + xt_q: columnwise quantized tensor 
of transposed x + xt_scale: columnwise quantization scale of transposed x + """ + M, N = x.shape + device = x.device + BLOCK_SIZE = hm.size(0) + R = 1 + x_q = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=device) + xt_q = torch.empty((N, M), dtype=torch.float8_e4m3fn, device=device) + x_scale = torch.empty((M, ), dtype=torch.float32, device=device) + xt_scale = torch.empty((N, ), dtype=torch.float32, device=device) + + grid_row = (triton.cdiv(M, R * BLOCK_SIZE),) + hadamard_quant_row_kernel[grid_row]( + x, + hm, + x_q, + x_scale, + M, + N, + BLOCK_SIZE, + R, + num_stages=6, + num_warps=4 + ) + + grid_col = (triton.cdiv(N, R * BLOCK_SIZE),) + hadamard_quant_col_kernel[grid_col]( + x, + hm, + xt_q, + xt_scale, + M, + N, + BLOCK_SIZE, + R, + num_stages=6, + num_warps=4 + ) + + return x_q, x_scale,xt_q, xt_scale diff --git a/linghe/quant/smooth.py b/linghe/quant/smooth.py new file mode 100644 index 0000000..4844511 --- /dev/null +++ b/linghe/quant/smooth.py @@ -0,0 +1,988 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import torch +import triton +import triton.language as tl + +from linghe.tools.util import round_up +from linghe.utils.transpose import triton_transpose_and_pad + + +@triton.jit +def tokenwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, max_ptr, + M, T, + N: tl.constexpr, + W: tl.constexpr, + EVEN: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + # row-wise read, row-wise write + smooth_scale = tl.load(ss_ptr + tl.arange(0, N))[None, :] + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + if CALIBRATE: + output_maxs = tl.zeros((W, N), dtype=tl.float32) + for i in range(T): + indices = pid * W * T + i * W + tl.arange(0, W) + if EVEN: + x = tl.load(x_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, N)[None, :]).to( + tl.float32) + else: + x = tl.load(x_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, N)[None, :], + mask=indices[:, None] < M).to( + tl.float32) + if CALIBRATE: + output_maxs = tl.maximum(tl.abs(x), output_maxs) + x *= smooth_scale + x_max = tl.max(tl.abs(x), axis=1) + scale = tl.maximum(x_max / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + if EVEN: + tl.store(qs_ptr + pid * W * T + i * W + tl.arange(0, W), scale, ) + else: + tl.store(qs_ptr + pid * W * T + i * W + tl.arange(0, W), scale, + mask=indices < M) + + x /= scale[:, None] + xq = x.to(q_ptr.dtype.element_ty) + if EVEN: + tl.store(q_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, + N)[ + None, :], + xq) + else: + tl.store(q_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, + N)[ + None, :], + xq, + mask=indices[:, None] < M) + if CALIBRATE: + output_maxs = tl.max(output_maxs, 0) + tl.store(max_ptr + pid * N + tl.arange(0, N), output_maxs) + + +@triton.jit +def blockwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, max_ptr, + M, + N, + H: tl.constexpr, + W: tl.constexpr, + EVEN: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + # row-wise read, row-wise write + offs = pid * W * N + tl.arange(0, W)[:, None] * N + tl.arange(0, H)[None, :] + soffs = tl.arange(0, H) + x_max = tl.zeros((W,), dtype=tl.float32) + n = tl.cdiv(N, H) + for i in range(n): + smooth_scale = tl.load(ss_ptr + soffs) + if EVEN: + x = tl.load(x_ptr + 
offs).to(tl.float32) + else: + x = tl.load(x_ptr + offs, + mask=pid * W + tl.arange(0, W)[:, None] < M).to( + tl.float32) + if CALIBRATE: + output_maxs = tl.max(x.abs(), 0) + tl.store(max_ptr + pid * N + i * H + tl.arange(0, H), output_maxs) + if REVERSE: + x = x * smooth_scale + else: + x = x / smooth_scale + x_max = tl.maximum(tl.max(tl.abs(x), axis=1), x_max) + offs += H + soffs += H + + scale = tl.maximum(x_max / 448, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + tl.store(qs_ptr + pid * W + tl.arange(0, W), scale, + mask=pid * W + tl.arange(0, W) < M) + + s = (1.0 / scale)[:, None] + + offs = pid * W * N + tl.arange(0, W)[:, None] * N + tl.arange(0, H)[None, :] + soffs = tl.arange(0, H) + for i in range(n): + smooth_scale = tl.load(ss_ptr + soffs) + if EVEN: + x = tl.load(x_ptr + offs) + else: + x = tl.load(x_ptr + offs, + mask=pid * W + tl.arange(0, W)[:, None] < M) + + if REVERSE: + xq = (x.to(tl.float32) * smooth_scale * s).to( + q_ptr.dtype.element_ty) + else: + xq = (x.to(tl.float32) / smooth_scale * s).to( + q_ptr.dtype.element_ty) + + if EVEN: + tl.store(q_ptr + offs, xq) + else: + # tl.store(q_ptr+offs, xq, mask=(i*H+tl.arange(0, H)[None,:]