diff --git a/.gitignore b/.gitignore index aaa3d00..2f5fb50 100644 --- a/.gitignore +++ b/.gitignore @@ -31,7 +31,7 @@ __pycache__/ # Distribution / packaging .Python -build/ +docs/build/ develop-eggs/ dist/ downloads/ @@ -68,5 +68,4 @@ pip-delete-this-directory.txt *.pyc *.json *.jsonl -*_ignore.py -.idea \ No newline at end of file +.idea diff --git a/README.md b/README.md index e7aedcd..08e3b3b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -

Linghe

+

linghe

Logo
@@ -20,16 +20,16 @@ ## *News or Update* 🔥
---
-- [2025/07] We implement multiple kernels for fp8 training with `Megatron-LM` blockwise quantization.
+- [2025/07] We implement multiple kernels for FP8 training with `Megatron-LM` blockwise quantization.

## Introduction
---
-Our repo, FLOPS, is designed for LLM training, especially for MoE training with fp8 quantizaiton. It provides 3 main categories of kernels:
+Our repo, linghe, is designed for LLM training, especially for MoE training with FP8 quantization. It provides 3 main categories of kernels:
- **Fused quantization kernels**: fuse quantization with previous layer, e.g., RMS norm and Silu.
-- **Memory-friendly kernels**: use dtype cast in kernels instead of casting out kernels, e.g., softmax cross entropy and moe router gemm.
-- **Other fused kernels**: fuse multiple IO-itensive operations, e.g., ROPE with qk-norm and transpose, permute and padding, group RMS norm with sigmoid gate.
+- **Memory-efficient kernels**: fuse multiple IO-intensive operations, e.g., ROPE with qk-norm.
+- **Implementation-optimized kernels**: use efficient Triton implementations, e.g., routing map padding instead of activation padding.

## Benchmark
@@ -37,25 +37,26 @@ Our repo, FLOPS, is designed for LLM training, especially for MoE training with
---
We benchmark on H800 with batch size 8192, hidden size 2048, num experts 256, activation experts 8.

| kernel | baseline(us) | linghe(us) | speedup |
-|--------|--------------|-----------|---------|
-| RMSNorm+Quantization(forward) | 159.3 us | 72.4 us | 2.2 |
-| Split+qk-norm+rope+transpose(forward) | 472 us | 59.1 us | 7.99 |
-| Split+qk-norm+rope+transpose(backward) | 645 us | 107.5 us | 6.0 |
-| Fp32 router gemm(forward) | 242.3 us | 61.6 us | 3.931 |
-| Fp32 router gemm(backward) | 232.7 us | 78.1 us | 2.979 |
-| Permute with padded indices | 388 us | 229.4 us | 1.69 |
-| Unpermute with padding indices | 988.6 us | 806.9 us | 1.23 |
-| Batch Silu+quantization(forward) | 6241.7 us | 1181.7 us | 5.28 |
-| Batch Silu+quantization(backward) | 7147.7 us | 2317.9 us | 3.08 |
-| Silu+quantization(forward) | 144.9 us | 58.2 us | 2.48 |
-| Silu+quantization(backward) | 163.4 us | 74.2 us | 2.2 |
-| fused linear gate(forward) | 160.4 us | 46.9 us | 3.42 |
-| fused linear gate(backward) | 572.9 us | 81.1 us | 7.06 |
-| Cross entropy(forward) | 2780.8 us | 818.2 us | 3.4 |
-| Cross entropy(backward) | 7086.3 us | 1781.0 us | 3.98 |
-| batch grad norm | 1733.7 us | 1413.7 us | 1.23 |
-| Batch count zero | 4997.9 us | 746.8 us | 6.69 |
-
+|--------|--------------|------------|---------|
+| RMSNorm+Quantization(forward) | 159.3 us | 72.4 us | 2.2 |
+| Split+qk-norm+rope+transpose(forward) | 472 us | 59.1 us | 7.99 |
+| Split+qk-norm+rope+transpose(backward) | 645 us | 107.5 us | 6.0 |
+| Fp32 router gemm(forward) | 242.3 us | 61.6 us | 3.931 |
+| Fp32 router gemm(backward) | 232.7 us | 78.1 us | 2.979 |
+| Permute with padded indices | 388 us | 229.4 us | 1.69 |
+| Unpermute with padding indices | 988.6 us | 806.9 us | 1.23 |
+| Batch Silu+quantization(forward) | 6241.7 us | 1181.7 us | 5.28 |
+| Batch Silu+quantization(backward) | 7147.7 us | 2317.9 us | 3.08 |
+| Silu+quantization(forward) | 144.9 us | 58.2 us | 2.48 |
+| Silu+quantization(backward) | 163.4 us | 74.2 us | 2.2 |
+| fused linear gate(forward) | 160.4 us | 46.9 us | 3.42 |
+| fused linear gate(backward) | 572.9 us | 81.1 us | 7.06 |
+| Cross entropy(forward) | 2780.8 us | 818.2 us | 3.4 |
+| Cross entropy(backward) | 7086.3 us | 1781.0 us | 3.98 |
+| batch 
grad norm | 1733.7 us | 1413.7 us | 1.23 | +| Batch count zero | 4997.9 us | 746.8 us | 6.69 | + +Other benchmark results can be obtained by running scripts in tests and benchmark folders. ## Examples --- @@ -65,4 +66,4 @@ Examples can be found in tests. ## Api Reference --- -Please refer to [API doc](docs/api.md) \ No newline at end of file +Please refer to [API](https://inclusionai.github.io/linghe/) \ No newline at end of file diff --git a/asserts/linghe.png b/asserts/linghe.png new file mode 100644 index 0000000..c1778e4 Binary files /dev/null and b/asserts/linghe.png differ diff --git a/build.sh b/build.sh index 904a121..b123d1c 100644 --- a/build.sh +++ b/build.sh @@ -2,4 +2,6 @@ rm -rf build && rm -rf dist && rm -rf linghe.egg-info && python setup.py develop && -python setup.py bdist_wheel && +python setup.py bdist_wheel + +# pdoc --output-dir docs -d google --no-include-undocumented --no-search --no-show-source linghe \ No newline at end of file diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index 000bfc2..0000000 --- a/docs/api.md +++ /dev/null @@ -1,212 +0,0 @@ -# API Reference - - -``` -linghe.utils.norm.triton_rms_norm_and_block_quant_forward(x, weight, eps:Optional[float]=1e-6, out:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None, rms:Optional[torch.Tensor]=None, round_scale: Optional[bool]=False, output_mode:Optional[int]=2) -``` - -Computes the forward pass of RMSNorm and block quantization. - -**Parameters:** -- x(*torch.Tensor*) - Input tensor. [M, N] -- weight(*torch.Tensor*) - RMSNorm weight. [N] -- eps(*float*) - epsilon value for L2 normalization. -- round_scale(*bool*) - Set whether to force power of 2 scales. -- rms(*torch.Tensor*) - Reciprocal of the root mean square of the input calculated over the last dimension.[N] -- output_mode - (*int*, {0, 1, 2}, default = 2) 0 only output non-transpose tensor, 1 only output transposed tensor, 2 return both. - ---- - -**` -Class linghe.facade.rope.QkNormHalfRopeFunction -`** - -``` -forward(qkv:, q_norm_weight, k_norm_weight, freqs, H, h, eps:Optional[float]=1e-6) -``` -Split qkv, and apply L2 nrom and ROPE on q and k. - -**Parameters:** -- qkv(*torch.Tensor*) - QKV tensor with size of [S, B, dim] -- freqs(*torch.Tensor*) - Freqs matrix based on half dim. -- H(*int*) - Number of attention heads. -- h(*int*) - Number of query groups. -- eps(*float*) - epsilon value for L2 normalization. - -``` -backward(grad_q, grad_k, grad_v) -``` -**Parameters:** -- grad_q(*torch.Tensor*) Grad of q tensor. -- grad_k(*torch.Tensor*) Grad of k tensor. -- grad_v(*torch.Tensor*) Gard of v tensor. - ---- - -**` -Class linghe.facade.fp32_linear.FusedFp32GEMM -`** - -Optimized fp32 gemm in router gate function. Convert bf16 input and weight to float32 during the gemm operation. - -``` -forward(input, weight) -``` -**Parameters:** -- input(*torch.Tensor*) - Input tensor with [B, S, dim], dtype of bf16. -- weight(*torch.Tensor*) - Weight tensor of router. - -``` -backward(grad_output) -``` -**Parameters:** -- grad_output(*torch.Tensor*) - Gradient of the activation. - ---- - -``` -linghe.utils.gather.triton_permute_with_mask_map(inp, scale, probs, row_id_map, num_out_tokens, contiguous, tokens_per_expert) -``` -Permute the tokens and probs based on the routing map. Index indicates row index of the output tensor(-1 means not selected). Perform well even when inp.size(0) < expert padding number, do not need extra explict padding. 
- -**Parameters:** -- inp(*torch.Tensor*) - Input hidden.[num_tokens, hidden_size] -- scale(*torch.Tensor*) - [num_tokens, scale_size] -- prob(*torch.Tensor*) - [num_tokens] Router prob. -- row_id_map(*torch.Tensor*) - [n_experts, num_tokens] Index indicates row index of the output tensor. -- num_out_tokens(*int*) - Output token count, including padding tokens. -- contiguous(*bool*) - Whether indices in row_id_map is contiguous, should be False if padded. -- token_per_expert(bool) - [num_experts] Token count per expert, non-blocking cuda tensor. - ---- - -``` -linghe.utils.scatter.triton_unpermute_with_mask_map(grad, row_id_map, probs) -``` -Unpermute a tensor with permuted tokens with router mapping. - -**Parameters:** -- inp(*torch.Tensor*) - [num_tokens, hidden_size] Permuted tokens. -- row_id_map(*torch.Tensor*) - [n_experts, num_tokens] Routing map to unpermute the tokens. -- prob(*torch.Tensor*) - [num_out_tokens] Permuted probs. - ---- - -``` -linghe.util.silu.triton_silu_and_block_quant_forward(x, out:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None, round_scale:Optional[bool]=False, output_mode:Optional[int]=2) -``` - -Applies the forward pass of Sigmoid Linear Unit(SiLU) element-wise and block quant.(used in shared expert layers.) - -**Parameters:** -- x(*torch.Tensor*) - Input tensor to be quanted. -- round_scale(*bool*) - Set whether to force power of 2 scales. -- output_mode - (*int*, {0, 1, 2}, default = 2) 0 only output non-transpose tensor, 1 only output transposed tensor, 2 return both. - ---- - -``` -linghe.util.silu.triton_silu_and_block_quant_backward(g, x, round_scale:Optional[bool]=False) -``` -**Parameters:** -- g(*torch.Tensor*) - Gradient tensor to be quanted. -- x(*torch.Tensor*) - Input tensor. -- round_scale(*bool*) - Set whether to force power of 2 scales. Default to False. - ---- - -``` -linghe.util.silu.triton_batch_weighted_silu_and_block_quant_forward(x, weight, counts, splits:Optional[List]=None ,out:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None, round_scale:Optional[bool]=False, output_mode:Optional[int]=2) -``` - -Fused op for batched weighted SiLU and block quant. - -**Parameters:** -- x(*torch.Tensor*) - Input tensor. -- weight(*torch.Tensor*) - Permuted probs -- couts(*torch.Tensor*) - Tokens per expert cuda tensor. -- splits(*List[int]*) - List of tokens per expert. If compute in batch mode should not be None. -- output_mode - (*int*, {0, 1, 2}, default = 2) 0 only output non-transpose tensor, 1 only output transposed tensor, 2 return both. - ---- - -``` -linghe.util.silu.triton_batch_weighted_silu_and_block_quant_backward(g, x, weight, counts, splits:Optional[List]=None, round_scale:Optional[bool]=False) -``` -**Parameters:** -- g(*torch.Tensor*) - Input gradient tensor. -- x(*torch.Tensor*) - Input tensor. -- weight(*torch.Tensor*) - Permuted probs -- couts(*torch.Tensor*) - Tokens per expert cuda tensor. -- splits(*List[int]*) - List of tokens per expert. If compute in batch mode should not be None. - ---- - -**` -Class linghe.facade.loss.SoftmaxCrossEntropyFunction -`** - -Prallel version of SoftmaxCrossEntropy. - -``` -forward(logits, labels, inplace:Optional[bool]=False) -``` - -Fast impl of softmax cross entropy. - -**Parameters:** -- logits(*torch.Tensor*) - Input logits. -- labels(*torch.Tensor*) - Input labels. -- inplace(*bool*) - Flag save for backward, whether logits ptr should replaced by grads tensor ptr. 
- -``` -backward(grad_output) -``` - -**Parameters:** -- grad_output(*torch.Tensor*) - Gradients tensor. - ---- - -``` -linghe.util.reduce.triton_batch_sum_with_ord(xs, ord:Optional[int]=2) -``` -Square sum the gards of all the experts. All the experts grads are applied simultaneously. - -**Parameters:** -- xs(*List[torch.Tensor]*) - Grads lists. -- ord(*int*) - Sum type. 1 for abs add and 2 for square add. - ---- - -``` -linghe.util.reduce.triton_batch_count_zero(xs) -``` -Prallel cout zeros in all the given grads lists. - -**Parameters:** -- xs(*List[torch.Tensor]*) - Grads lists. - ---- - -**` -Class linghe.facade.norm.GroupNormGateFunction -`** -Fused operation of group RMSNorm and sigmoid gate function. - -``` -forward(x, gate, weight, eps:Optional[float]=1e-6, group_size:Optional[int]=4) -``` -Note that the output shape is transposed [S, B, dim] - -**Parameters:** - -- x(*torch.Tensor*) - [B, S, dim] Input tensor. -- gate(*torch.Tensor*) - [S, B, dim] -- weight(*torch.Tensor*) - [dim] - -``` -backward(grad) -``` -**Parameters:** -- grad(*torch.Tensor*) - [S, B, dim] Grads of input tensor. diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..514a509 --- /dev/null +++ b/docs/index.html @@ -0,0 +1,7 @@ + + + + + + + diff --git a/docs/linghe.html b/docs/linghe.html new file mode 100644 index 0000000..a7cef32 --- /dev/null +++ b/docs/linghe.html @@ -0,0 +1,52 @@ + + + + + + + linghe API documentation + + + + + + + + + +
+
+

+linghe

+ + + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe.png b/docs/linghe.png deleted file mode 100644 index fc1f51c..0000000 Binary files a/docs/linghe.png and /dev/null differ diff --git a/docs/linghe/facade.html b/docs/linghe/facade.html new file mode 100644 index 0000000..e50704a --- /dev/null +++ b/docs/linghe/facade.html @@ -0,0 +1,61 @@ + + + + + + + linghe.facade API documentation + + + + + + + + + +
+
+

+linghe.facade

+ + + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/facade/add.html b/docs/linghe/facade/add.html new file mode 100644 index 0000000..57a6c73 --- /dev/null +++ b/docs/linghe/facade/add.html @@ -0,0 +1,87 @@ + + + + + + + linghe.facade.add API documentation + + + + + + + + + +
+
+

+linghe.facade.add

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + inplace_add(x: torch.Tensor, y: torch.Tensor): + + +
+ + +

inplace add y to x with mixed precision

+ +
Arguments:
+ +
    +
  • x: to be updated
  • +
  • y: add to x
  • +
+ +
Returns:
+ +
+

return updated x tensor

+
+
+ + +
+
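A minimal usage sketch of `inplace_add` (the import path follows this page's module name; shapes and dtypes are illustrative):

```python
import torch
from linghe.facade.add import inplace_add

# mixed-precision accumulation: an fp32 master buffer updated by a bf16 delta
x = torch.randn(1024, 2048, device="cuda", dtype=torch.float32)
y = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)

x = inplace_add(x, y)  # x is updated in place and also returned
```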
+ + \ No newline at end of file diff --git a/docs/linghe/facade/fp32_gemm.html b/docs/linghe/facade/fp32_gemm.html new file mode 100644 index 0000000..b4858e4 --- /dev/null +++ b/docs/linghe/facade/fp32_gemm.html @@ -0,0 +1,88 @@ + + + + + + + linghe.facade.fp32_gemm API documentation + + + + + + + + + +
+
+

+linghe.facade.fp32_gemm

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + fp32_gemm(input: torch.Tensor, weight: torch.Tensor): + + +
+ + +

gemm with bf16/fp16 inputs and float32 output, +currently used in MoE router gemm.

+ +
Arguments:
+ +
    +
  • input: bf16/fp16 activation tensor
  • +
  • weight: bf16/fp16 weight tensor
  • +
+ +
Returns:
+ +
+

output of gemm

+
+
+ + +
+
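A usage sketch for the router use case (the `input @ weight` orientation and all shapes are assumptions, not documented here):

```python
import torch
from linghe.facade.fp32_gemm import fp32_gemm

# 4096 tokens, hidden size 2048, 256 experts (illustrative)
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2048, 256, device="cuda", dtype=torch.bfloat16)  # router weight, orientation assumed

# router logits come back in float32 for a numerically stable softmax/top-k
router_logits = fp32_gemm(x, w)
```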
+ + \ No newline at end of file diff --git a/docs/linghe/facade/fp32_linear.html b/docs/linghe/facade/fp32_linear.html new file mode 100644 index 0000000..09c0fd7 --- /dev/null +++ b/docs/linghe/facade/fp32_linear.html @@ -0,0 +1,130 @@ + + + + + + + linghe.facade.fp32_linear API documentation + + + + + + + + + +
+
+

+linghe.facade.fp32_linear

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + class + FusedFp32GEMM(torch.autograd.function.Function): + + +
+ + +

gemm with bf16/fp16 inputs and float32 output, +currently used in MoE router gemm.

+
+ + +
+
+
@staticmethod
+ + def + forward(ctx, input: torch.Tensor, weight: torch.Tensor): + + +
+ + +

gemm forward with bf16/fp16 inputs and float32 output.

+ +
Arguments:
+ +
    +
  • ctx:
  • +
  • input: bf16/fp16 act tensor
  • +
  • weight: bf16/fp16 weight tensor
  • +
+ +
Returns:
+ +
+

output of gemm

+
+
+ + +
+
+
+
@staticmethod
+ + def + backward(ctx, grad_output): + + +
+ + +

backward

+
+ + +
+
+
+ + \ No newline at end of file diff --git a/docs/linghe/facade/hadamard_quant_linear.html b/docs/linghe/facade/hadamard_quant_linear.html new file mode 100644 index 0000000..530195c --- /dev/null +++ b/docs/linghe/facade/hadamard_quant_linear.html @@ -0,0 +1,190 @@ + + + + + + + linghe.facade.hadamard_quant_linear API documentation + + + + + + + + + +
+
+

+linghe.facade.hadamard_quant_linear

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + class + HadamardQuantLinear(torch.nn.modules.module.Module): + + +
+ + +

Base class for all neural network modules.

+ +

Your models should also subclass this class.

+ +

Modules can also contain other Modules, allowing them to be nested in +a tree structure. You can assign the submodules as regular attributes::

+ +
import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Model(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+ +

Submodules assigned in this way will be registered, and will also have their +parameters converted when you call to(), etc.

+ +
+ +

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+ +
+ +

:ivar training: Boolean represents whether this module is in training or + evaluation mode. +:vartype training: bool

+
+ + +
+
+ + HadamardQuantLinear( in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) + + +
+ + +

a naive implementation of hadamard transformation and quantization

+ +
Arguments:
+ +
    +
  • in_features: in feature number
  • +
  • out_features: out feature number
  • +
  • bias: whether use bias
  • +
  • device: weight device
  • +
  • dtype: weight dtype
  • +
  • impl: implementation of hadamard quantization
  • +
+
+ + +
+
+
+ + def + forward(self, input: torch.Tensor) -> torch.Tensor: + + +
+ + +

Define the computation performed at every call.

+ +

Should be overridden by all subclasses.

+ +
+ +

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+ +
+
+ + +
+
+
+ + def + extra_repr(self) -> str: + + +
+ + +

Return the extra representation of the module.

+ +

To print customized extra information, you should re-implement +this method in your own modules. Both single-line and multi-line +strings are acceptable.

+
+ + +
+
+
+ + \ No newline at end of file diff --git a/docs/linghe/facade/loss.html b/docs/linghe/facade/loss.html new file mode 100644 index 0000000..21cda4a --- /dev/null +++ b/docs/linghe/facade/loss.html @@ -0,0 +1,88 @@ + + + + + + + linghe.facade.loss API documentation + + + + + + + + + +
+
+

+linghe.facade.loss

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + softmax_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, inplace: bool = False): + + +
+ + +

softmax cross entropy

+ +
Arguments:
+ +
    +
  • logits: logits tensor, shape [...,dim]
  • +
  • labels: labels tensor, shape [...]
  • +
  • inplace: update gradient in the logits tensor if True
  • +
+ +
Returns:
+ +
+

per token loss

+
+
+ + +
+
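A sketch of the facade in a training step (vocabulary size and batch shape are illustrative; per the argument description above, `inplace=True` would reuse the logits buffer for the gradient):

```python
import torch
from linghe.facade.loss import softmax_cross_entropy

logits = torch.randn(8192, 32000, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (8192,), device="cuda")

loss = softmax_cross_entropy(logits, labels, inplace=False)  # per-token loss, shape [8192]
loss.mean().backward()
```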
+ + \ No newline at end of file diff --git a/docs/linghe/facade/norm.html b/docs/linghe/facade/norm.html new file mode 100644 index 0000000..e858009 --- /dev/null +++ b/docs/linghe/facade/norm.html @@ -0,0 +1,122 @@ + + + + + + + linghe.facade.norm API documentation + + + + + + + + + +
+
+

+linghe.facade.norm

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06): + + +
+ + +

rms norm of x with weight

+ +
Arguments:
+ +
    +
  • x: activation tensor
  • +
  • weight: weight tensor
  • +
  • eps: epsilon for RMS
  • +
+ +
Returns:
+ +
+

rms output

+
+
+ + +
+
+
+ + def + group_norm_gate( attn_output: torch.Tensor, gate: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06, group_size: int = 4): + + +
+ + +

return group_rms_norm(transpose(attn_output, [0,1]), weight) * sigmoid(gate)

+ +
Arguments:
+ +
    +
  • attn_output: output of core attn, shape [bs, length, n_heads, head_dim]
  • +
  • gate: gate tensor for attention output, shape [length, bs, dim]
  • +
  • weight: weight of RMS norm, shape [dim]
  • +
  • eps: epsilon for RMS
  • +
  • group_size: group size of group RMS norm
  • +
+ +
Returns:
+ +
+

output with shape [length, bs, dim]

+
+
+ + +
+
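A shape-oriented sketch (dimensions are illustrative; `dim = n_heads * head_dim` is an assumption):

```python
import torch
from linghe.facade.norm import group_norm_gate

bs, length, n_heads, head_dim = 2, 4096, 16, 128
dim = n_heads * head_dim

attn_out = torch.randn(bs, length, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
gate = torch.randn(length, bs, dim, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(dim, device="cuda", dtype=torch.bfloat16)

# note: the output is already transposed to [length, bs, dim]
out = group_norm_gate(attn_out, gate, weight, eps=1e-6, group_size=4)
```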
+ + \ No newline at end of file diff --git a/docs/linghe/facade/rope.html b/docs/linghe/facade/rope.html new file mode 100644 index 0000000..0d39eca --- /dev/null +++ b/docs/linghe/facade/rope.html @@ -0,0 +1,94 @@ + + + + + + + linghe.facade.rope API documentation + + + + + + + + + +
+
+

+linghe.facade.rope

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + qk_norm_half_rope( qkv: torch.Tensor, q_norm_weight: torch.Tensor, k_norm_weight: torch.Tensor, freqs: torch.Tensor, H: int = 32, h: int = 4, eps: float = 1e-06): + + +
+ + +

split qkv to q/k/v, apply qk norm and half rope to q/k, transpose q/k/v to flash-attention layout

+ +
Arguments:
+ +
    +
  • qkv: QKV tensor with size of [S, B, dim], heads are interleaved
  • +
  • q_norm_weight: rms norm weight for query
  • +
  • k_norm_weight: rms norm weight for key
  • +
  • freqs: Freqs tensor based on half dim.
  • +
  • H: Number of attention heads.
  • +
  • h: Number of key/value heads.
  • +
  • eps: epsilon value for L2 normalization.
  • +
+ +
Returns:
+ +
+

qo: shape [B, S, H, head_dim] + ko: shape [B, S, h, head_dim] + vo: shape [B, S, h, head_dim]

+
+
+ + +
+
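A shape-oriented sketch; the packed qkv layout and head counts follow the documentation above, while the freqs layout and the per-head norm-weight shapes are assumptions:

```python
import torch
from linghe.facade.rope import qk_norm_half_rope

S, B, H, h, head_dim = 4096, 2, 32, 4, 128
dim = (H + 2 * h) * head_dim  # packed q+k+v, heads interleaved

qkv = torch.randn(S, B, dim, device="cuda", dtype=torch.bfloat16)
q_w = torch.ones(head_dim, device="cuda", dtype=torch.bfloat16)  # per-head RMS weight (assumed shape)
k_w = torch.ones(head_dim, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(S, head_dim // 2, device="cuda", dtype=torch.float32)  # half-dim freqs (assumed shape)

q, k, v = qk_norm_half_rope(qkv, q_w, k_w, freqs, H=H, h=h, eps=1e-6)
# q: [B, S, H, head_dim]; k, v: [B, S, h, head_dim], ready for flash attention
```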
+ + \ No newline at end of file diff --git a/docs/linghe/facade/smooth_quant_linear.html b/docs/linghe/facade/smooth_quant_linear.html new file mode 100644 index 0000000..2752e47 --- /dev/null +++ b/docs/linghe/facade/smooth_quant_linear.html @@ -0,0 +1,179 @@ + + + + + + + linghe.facade.smooth_quant_linear API documentation + + + + + + + + + +
+
+

+linghe.facade.smooth_quant_linear

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + class + QuantLinear(torch.nn.modules.module.Module): + + +
+ + +

Base class for all neural network modules.

+ +

Your models should also subclass this class.

+ +

Modules can also contain other Modules, allowing them to be nested in +a tree structure. You can assign the submodules as regular attributes::

+ +
import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Model(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+ +

Submodules assigned in this way will be registered, and will also have their +parameters converted when you call to(), etc.

+ +
+ +

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+ +
+ +

:ivar training: Boolean represents whether this module is in training or + evaluation mode. +:vartype training: bool

+
+ + +
+
+ + QuantLinear( in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) + + +
+ + +

Initialize internal Module state, shared by both nn.Module and ScriptModule.

+
+ + +
+
+
+ + def + forward(self, input: torch.Tensor) -> torch.Tensor: + + +
+ + +

Define the computation performed at every call.

+ +

Should be overridden by all subclasses.

+ +
+ +

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+ +
+
+ + +
+
+
+ + def + extra_repr(self) -> str: + + +
+ + +

Return the extra representation of the module.

+ +

To print customized extra information, you should re-implement +this method in your own modules. Both single-line and multi-line +strings are acceptable.

+
+ + +
+
+
+ + \ No newline at end of file diff --git a/docs/linghe/facade/transpose.html b/docs/linghe/facade/transpose.html new file mode 100644 index 0000000..9d645db --- /dev/null +++ b/docs/linghe/facade/transpose.html @@ -0,0 +1,86 @@ + + + + + + + linghe.facade.transpose API documentation + + + + + + + + + +
+
+

+linghe.facade.transpose

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + transpose_dim01(x): + + +
+ + +

transpose the first two dims of a tensor, x.ndim should not be greater than 4

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
+ +
Returns:
+ +
+

a transposed tensor

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/gemm.html b/docs/linghe/gemm.html new file mode 100644 index 0000000..3175a31 --- /dev/null +++ b/docs/linghe/gemm.html @@ -0,0 +1,56 @@ + + + + + + + linghe.gemm API documentation + + + + + + + + + +
+
+

+linghe.gemm

+ + + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/gemm/blockwise_fp8_gemm.html b/docs/linghe/gemm/blockwise_fp8_gemm.html new file mode 100644 index 0000000..69ab790 --- /dev/null +++ b/docs/linghe/gemm/blockwise_fp8_gemm.html @@ -0,0 +1,56 @@ + + + + + + + linghe.gemm.blockwise_fp8_gemm API documentation + + + + + + + + + +
+
+

+linghe.gemm.blockwise_fp8_gemm

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/gemm/channelwise_fp8_gemm.html b/docs/linghe/gemm/channelwise_fp8_gemm.html new file mode 100644 index 0000000..c8c6831 --- /dev/null +++ b/docs/linghe/gemm/channelwise_fp8_gemm.html @@ -0,0 +1,93 @@ + + + + + + + linghe.gemm.channelwise_fp8_gemm API documentation + + + + + + + + + +
+
+

+linghe.gemm.channelwise_fp8_gemm

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_scaled_mm( a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, out_dtype=torch.float32, c=None, accum=True): + + +
+ + +

similar to torch._scaled_mm, supports accumulating the gemm output into c + and low-precision output tensors

+ +
Arguments:
+ +
    +
  • a: left fp8 tensor
  • +
  • b: right fp8 tensor, column-major
  • +
  • a_scale: fp32 scale of a
  • +
  • b_scale: fp32 scale of b
  • +
  • out_dtype: output tensor dtype
  • +
  • c: output tensor
  • +
  • accum: accumulate output on c if True
  • +
+ +
Returns:
+ +
+

c: output tensor

+
+
+ + +
+
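A sketch pairing `triton_scaled_mm` with rowwise quantization (the quantizer pairing and shapes are assumptions; `b` is passed column-major via `.t()`):

```python
import torch
from linghe.quant.channel import triton_row_quant
from linghe.gemm.channelwise_fp8_gemm import triton_scaled_mm

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)

x_q, x_s = triton_row_quant(x)  # fp8 data + fp32 rowwise scales
w_q, w_s = triton_row_quant(w)

# accumulate into an existing fp32 buffer, e.g. across micro-batches
c = torch.zeros(4096, 1024, device="cuda", dtype=torch.float32)
c = triton_scaled_mm(x_q, w_q.t(), x_s, w_s, out_dtype=torch.float32, c=c, accum=True)
```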
+ + \ No newline at end of file diff --git a/docs/linghe/gemm/fp32_gemm.html b/docs/linghe/gemm/fp32_gemm.html new file mode 100644 index 0000000..f0f3faf --- /dev/null +++ b/docs/linghe/gemm/fp32_gemm.html @@ -0,0 +1,222 @@ + + + + + + + linghe.gemm.fp32_gemm API documentation + + + + + + + + + +
+
+

+linghe.gemm.fp32_gemm

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_fp32_gemm(a: torch.Tensor, b: torch.Tensor): + + +
+ + +

return fp32 gemm result with fp16/bf16 inputs, + it is mainly used for the MoE router GEMM + and is NOT suitable for large GEMMs

+ +
Arguments:
+ +
    +
  • a: left matrix with fp16/bf16 precision
  • +
  • b: right matrix with fp16/bf16 precision
  • +
+ +
Returns:
+ +
+

c: output with fp32 precision

+
+
+ + +
+
+
+ + def + triton_fp32_gemm_for_backward(a: torch.Tensor, b: torch.Tensor): + + +
+ + +

mixed-precision gemm for backward, a@b.float()

+ +
Arguments:
+ +
    +
  • a: input gradient, fp32
  • +
  • b: gemm weight, bf16/fp16
  • +
+ +
Returns:
+ +
+

c: gradient of activation

+
+
+ + +
+
+
+ + def + triton_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor): + + +
+ + +

mixed-precision gemm for updating weight

+ +
Arguments:
+ +
    +
  • a: gradient of output, fp32
  • +
  • b: input activation, bf16/fp16
  • +
+ +
Returns:
+ +
+

c: gradient of weight

+
+
+ + +
+
+
+ + def + triton_scaled_fp32_gemm(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor): + + +
+ + +

c = (a * scale[:, None]) @ b +this kernel is used to fuse RMSNorm and quantization in the MoE layer +native implementation: + y = rms_norm(x), + y_q = quantization(y), + router_logits = y@w +we cannot fuse rms_norm and quantization natively, +as we would still need the bf16 y for the MoE router gemm +fused implementation: + y_q, rms = quantization(rms_norm(x)) + router_logits = (x/rms)@w +so we need a scaled fp32 gemm kernel

+ +
Arguments:
+ +
    +
  • a: activation tensor
  • +
  • b: weight tensor
  • +
  • scale: scale for activation tensor, 1/rms
  • +
+ +

Returns:

+
+ + +
+
+
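A sketch of the fused path described above (the five-value return order of the fused quantizer and the router-weight orientation are assumptions; per the norm page, the returned rms is already the reciprocal, matching the "1/rms" scale expected here):

```python
import torch
from linghe.utils.norm import triton_rms_norm_and_block_quant_forward
from linghe.gemm.fp32_gemm import triton_scaled_fp32_gemm

x = torch.randn(8192, 2048, device="cuda", dtype=torch.bfloat16)
g = torch.ones(2048, device="cuda", dtype=torch.bfloat16)        # RMSNorm weight
w = torch.randn(2048, 256, device="cuda", dtype=torch.bfloat16)  # router weight (orientation assumed)

# fused RMSNorm + blockwise quantization; rms holds 1/rms(x) per token
y_q, y_scale, rms, *_ = triton_rms_norm_and_block_quant_forward(x, g)

# router logits without materializing the bf16 normalized tensor:
# logits = (x * rms[:, None]) @ w
logits = triton_scaled_fp32_gemm(x, w, rms)
```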
+ + def + triton_scaled_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor): + + +
+ + +

see triton_scaled_fp32_gemm

+ +
Arguments:
+ +
    +
  • a: y
  • +
  • b: activation before RMS norm
  • +
  • scale: 1/rms
  • +
+ +
Returns:
+ +
+

dw

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant.html b/docs/linghe/quant.html new file mode 100644 index 0000000..bd6573d --- /dev/null +++ b/docs/linghe/quant.html @@ -0,0 +1,58 @@ + + + + + + + linghe.quant API documentation + + + + + + + + + +
+
+

+linghe.quant

+ + + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/block.html b/docs/linghe/quant/block.html new file mode 100644 index 0000000..a562b18 --- /dev/null +++ b/docs/linghe/quant/block.html @@ -0,0 +1,89 @@ + + + + + + + linghe.quant.block API documentation + + + + + + + + + +
+
+

+linghe.quant.block

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_block_quant(x, block_size=128, round_scale=False): + + +
+ + +

blockwise quantize x

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • block_size: block size
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

y: quantized tensor, float8_e4m3fn + s: quantization scale, float32

+
+
+ + +
+
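A minimal sketch (shapes illustrative; one fp32 scale per 128x128 block is implied by the description above):

```python
import torch
from linghe.quant.block import triton_block_quant

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)

# round_scale=True snaps scales to powers of 2
y, s = triton_block_quant(x, block_size=128, round_scale=True)
# y: float8_e4m3fn with x's shape; s: fp32, one scale per 128x128 block
```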
+ + \ No newline at end of file diff --git a/docs/linghe/quant/block/block.html b/docs/linghe/quant/block/block.html new file mode 100644 index 0000000..58f0ab6 --- /dev/null +++ b/docs/linghe/quant/block/block.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.block.block API documentation + + + + + + + + + +
+
+

+linghe.quant.block.block

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/block/group.html b/docs/linghe/quant/block/group.html new file mode 100644 index 0000000..11958c6 --- /dev/null +++ b/docs/linghe/quant/block/group.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.block.group API documentation + + + + + + + + + +
+
+

+linghe.quant.block.group

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/channel.html b/docs/linghe/quant/channel.html new file mode 100644 index 0000000..80243be --- /dev/null +++ b/docs/linghe/quant/channel.html @@ -0,0 +1,152 @@ + + + + + + + linghe.quant.channel API documentation + + + + + + + + + +
+
+

+linghe.quant.channel

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_row_quant(x, round_scale=False): + + +
+ + +

rowwise quantize x

+ +
Arguments:
+ +
    +
  • x: input x
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

x_q: quantized tensor + x_scale: quantization scale

+
+
+ + +
+
+
+ + def + triton_tokenwise_row_quant(x, out=None, scale=None, round_scale=False): + + +
+ + +

rowwise quantize x, for inputs whose dim size is a power of 2

+ +
Arguments:
+ +
    +
  • x: input x
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

out: quantized tensor + scale: quantization scale

+
+
+ + +
+
+
+ + def + triton_transpose_row_quant(x, round_scale=False): + + +
+ + +

transpose x and row quantize x

+ +
Arguments:
+ +
    +
  • x: input x
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

x_q: quantized tensor + x_scale: quantization scale

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/channel/channel.html b/docs/linghe/quant/channel/channel.html new file mode 100644 index 0000000..7030e74 --- /dev/null +++ b/docs/linghe/quant/channel/channel.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.channel.channel API documentation + + + + + + + + + +
+
+

+linghe.quant.channel.channel

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/group.html b/docs/linghe/quant/group.html new file mode 100644 index 0000000..e0191d6 --- /dev/null +++ b/docs/linghe/quant/group.html @@ -0,0 +1,89 @@ + + + + + + + linghe.quant.group API documentation + + + + + + + + + +
+
+

+linghe.quant.group

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_group_quant(x, dtype=torch.float8_e4m3fn, group_size=128, round_scale=False): + + +
+ + +

groupwise quantize x, groups are laid out in rowwise format

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • group_size: group size
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

y: quantized tensor, float8_e4m3fn + s: quantization scale, float32

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/hadamard.html b/docs/linghe/quant/hadamard.html new file mode 100644 index 0000000..917cf30 --- /dev/null +++ b/docs/linghe/quant/hadamard.html @@ -0,0 +1,90 @@ + + + + + + + linghe.quant.hadamard API documentation + + + + + + + + + +
+
+

+linghe.quant.hadamard

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_hadamard_quant(x, hm): + + +
+ + +

apply hadamard transformation and then quantize the transformed tensor

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • hm: hadamard matrix
  • +
+ +
Returns:
+ +
+

x_q: rowwise quantized tensor of non-transposed x + x_scale: rowwise quantization scale of non-transposed x + xt_q: columnwise quantized tensor of transposed x + xt_scale: columnwise quantization scale of transposed x

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/hadamard/seperate_hadamard.html b/docs/linghe/quant/hadamard/seperate_hadamard.html new file mode 100644 index 0000000..dce669e --- /dev/null +++ b/docs/linghe/quant/hadamard/seperate_hadamard.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.hadamard.seperate_hadamard API documentation + + + + + + + + + +
+
+

+linghe.quant.hadamard.seperate_hadamard

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/smooth.html b/docs/linghe/quant/smooth.html new file mode 100644 index 0000000..c903c18 --- /dev/null +++ b/docs/linghe/quant/smooth.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.smooth API documentation + + + + + + + + + +
+
+

+linghe.quant.smooth

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/smooth/reused_smooth.html b/docs/linghe/quant/smooth/reused_smooth.html new file mode 100644 index 0000000..927018e --- /dev/null +++ b/docs/linghe/quant/smooth/reused_smooth.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.smooth.reused_smooth API documentation + + + + + + + + + +
+
+

+linghe.quant.smooth.reused_smooth

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/quant/smooth/seperate_smooth.html b/docs/linghe/quant/smooth/seperate_smooth.html new file mode 100644 index 0000000..5ba5130 --- /dev/null +++ b/docs/linghe/quant/smooth/seperate_smooth.html @@ -0,0 +1,56 @@ + + + + + + + linghe.quant.smooth.seperate_smooth API documentation + + + + + + + + + +
+
+

+linghe.quant.smooth.seperate_smooth

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils.html b/docs/linghe/utils.html new file mode 100644 index 0000000..e37e3fe --- /dev/null +++ b/docs/linghe/utils.html @@ -0,0 +1,64 @@ + + + + + + + linghe.utils API documentation + + + + + + + + + +
+
+

+linghe.utils

+ + + + + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/add.html b/docs/linghe/utils/add.html new file mode 100644 index 0000000..9c8fe47 --- /dev/null +++ b/docs/linghe/utils/add.html @@ -0,0 +1,88 @@ + + + + + + + linghe.utils.add API documentation + + + + + + + + + +
+
+

+linghe.utils.add

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_inplace_add(x: torch.Tensor, y: torch.Tensor, accum: bool = True): + + +
+ + +

inplace add y to x

+ +
Arguments:
+ +
    +
  • x: Tensor
  • +
  • y: Tensor
  • +
  • accum: x += y if accum=True else x.copy_(y)
  • +
+ +
Returns:
+ +
+

updated x

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/dot.html b/docs/linghe/utils/dot.html new file mode 100644 index 0000000..49a7b28 --- /dev/null +++ b/docs/linghe/utils/dot.html @@ -0,0 +1,88 @@ + + + + + + + linghe.utils.dot API documentation + + + + + + + + + +
+
+

+linghe.utils.dot

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_dot(x, y): + + +
+ + +

vector dot multiply, output = sum(x*y, 1), +it is used to calculate gradient of router weight

+ +
Arguments:
+ +
    +
  • x:
  • +
  • y:
  • +
+ +
Returns:
+ +
+

output of sum(x*y, 1)

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/gather.html b/docs/linghe/utils/gather.html new file mode 100644 index 0000000..8f18e38 --- /dev/null +++ b/docs/linghe/utils/gather.html @@ -0,0 +1,342 @@ + + + + + + + linghe.utils.gather API documentation + + + + + + + + + +
+
+

+linghe.utils.gather

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_make_row_id_map(routing_map: torch.Tensor, multiple_of: int = 1): + + +
+ + +

make row id map, values in the tensor are the row indices

+ +
Arguments:
+ +
    +
  • routing_map: a tensor of 0/1 values, 1 indicates routed
  • +
  • multiple_of: padding the tokens of each expert to multiple of this value
  • +
+ +
Returns:
+ +
+

row id map with shape [n_tokens, n_experts]

+
+
+ + +
+
+
+ + def + triton_make_row_id_map_and_indices(routing_map: torch.Tensor, num_out_tokens: int, multiple_of: int = 1): + + +
+ + +

similar to triton_make_row_id_map, but outputs an indices tensor as well

+ +
Arguments:
+ +
    +
  • routing_map: [n_tokens, n_experts]
  • +
  • num_out_tokens: sum(round_up_to(n_tokens, multiple_of))
  • +
  • multiple_of: padding the tokens of each expert to this value
  • +
+ +
Returns:
+ +
+

row_id_map: [n_tokens, n_experts] + row_indices: [num_out_tokens]

+
+
+ + +
+
+
+ + def + triton_index_select(x, indices, scale=None, out=None, scale_out=None): + + +
+ + +

index select for quantized tensor

+ +
Arguments:
+ +
    +
  • x: [bs, dim]
  • +
  • indices: [K]
  • +
  • scale: [bs]
  • +
+ +
Returns:
+ +
+

out: output of selected x + scale_out: scale of selected scale

+
+
+ + +
+
+
+ + def + triton_permute_with_mask_map( inp: torch.Tensor, scale: torch.Tensor, probs: torch.Tensor, row_id_map: torch.Tensor, num_out_tokens: int, contiguous: bool = True, tokens_per_expert: Optional[torch.Tensor] = None): + + +
+ + +

gather quantized tensor with row id map

+ +
Arguments:
+ +
    +
  • inp: [num_tokens, hidden_size], rowwise quantized tensor
  • +
  • scale: [num_tokens], quantization scale
  • +
  • probs: router prob, used as weight
  • +
  • row_id_map: [n_experts, num_tokens] +index >= 0: row index of output tensor +index == -1: ignore +Note: index may not be contiguous
  • +
  • num_out_tokens: output token count, including padding tokens
  • +
  • contiguous: whether indices in row_id_map is contiguous, +False means padded
  • +
  • tokens_per_expert: [num_experts], token count per expert, +non-blocking cuda tensor
  • +
+ +
Returns:
+ +
+

output: permuted quantized tensor + permuted_scale: permuted quantization scale + permuted_probs: permuted router prob

+
+
+ + +
+
+
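A sketch of the permute path. Note that this file documents the row id map as [n_tokens, n_experts] for `triton_make_row_id_map` but [n_experts, num_tokens] for `triton_permute_with_mask_map`, so treat the wiring below as an assumption; the routing map construction and `num_out_tokens` formula follow the documentation above:

```python
import torch
from linghe.quant.channel import triton_row_quant
from linghe.utils.gather import triton_make_row_id_map, triton_permute_with_mask_map

num_tokens, hidden, n_experts, topk = 4096, 2048, 8, 2

# 0/1 routing map: each token is routed to its topk experts
scores = torch.rand(num_tokens, n_experts, device="cuda")
routing_map = torch.zeros_like(scores, dtype=torch.int32)
routing_map.scatter_(1, scores.topk(topk, dim=1).indices, 1)

row_id_map = triton_make_row_id_map(routing_map, multiple_of=32)
counts = routing_map.sum(0)  # tokens per expert
num_out_tokens = int(((counts + 31) // 32 * 32).sum())  # sum(round_up(tokens, 32))

x_q, x_scale = triton_row_quant(torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16))
probs = torch.rand(num_tokens, device="cuda")  # router probs (shape assumed)

out, out_scale, out_probs = triton_permute_with_mask_map(
    x_q, x_scale, probs, row_id_map, num_out_tokens,
    contiguous=False,           # indices are padded, hence not contiguous
    tokens_per_expert=counts,
)
```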
+ + def + triton_batch_transpose_smooth_permute_with_indices( x, scale, org_smooth_scale, smooth_scales, indices, token_count_per_expert, splits, x_q=None, x_scale=None, round_scale=False): + + +
+ + +

used for smooth quantization backward in megatron 0.12, +x is gathered, requantized, padded to a multiple of 32, and transposed

+ +
Arguments:
+ +
    +
  • x: dy, [bs, dim], it is smooth quantized
  • +
  • scale: [bs], quantized scale
  • +
  • org_smooth_scale: [dim]
  • +
  • smooth_scales: [n_experts, dim]
  • +
  • indices: [sum(tokens_per_experts)]
  • +
  • token_count_per_expert: [n_experts], tensor of token count per expert
  • +
  • splits: [n_experts], list of token_count_per_expert
  • +
  • round_scale: round quantization scale to power of 2
  • +
+ +
Returns:
+ +
+

x_q: [sum(roundup(tokens_per_experts)) * dim] + x_scale: [sum(roundup(tokens_per_experts))]

+
+
+ + +
+
+
+ + def + triton_smooth_weighted_permute_with_indices( grads, tokens, smooth_scales, token_count_per_expert, indices, x_q=None, x_scale=None, x_sum=None, reverse=False, round_scale=False): + + +
+ + +

select, smooth, and quantize, used in megatron 0.11 all2all MoE

+ +
Arguments:
+ +
    +
  • grads: [bs, dim]
  • +
  • tokens: [bs, dim]
  • +
  • smooth_scales: [n_experts, dim]
  • +
  • token_count_per_expert: [n_experts]
  • +
  • indices: [n_experts*topk]
  • +
  • reverse: whether scale is 1/scale
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

x_q: [bs*topk, dim] + x_scale: [bs*topk] + x_sum: [bs*topk]

+
+
+ + +
+
+
+ + def + triton_smooth_permute_with_indices( grad_data, grad_scale, smooth_scales, token_count_per_expert, indices, x_q=None, x_scale=None, reverse=False, round_scale=False): + + +
+ + +

select, smooth, and quantize

+ +
Arguments:
+ +
    +
  • grad_data: [bs, dim]
  • +
  • grad_scale: [bs]
  • +
  • smooth_scales: [n_experts, dim]
  • +
  • token_count_per_expert: [n_experts]
  • +
  • indices: [n_experts*topk]
  • +
  • x_q: [bs*topk, dim]
  • +
  • x_scale: [bs*topk]
  • +
  • reverse:
  • +
  • round_scale:
  • +
+ +

Returns:

+
+ + +
+
+
+ + def + triton_smooth_permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, scale: torch.Tensor, num_tokens: int, num_experts: int, num_out_tokens: int, hidden_size: int, smooth_scales: torch.Tensor, reverse=True, round_scale=False): + + +
+ + +

gather, optionally dequantize, and smooth quantize

+ +
Arguments:
+ +
    +
  • inp: [num_tokens, hidden_size], rowwise quantized tensor
  • +
  • row_id_map: [n_experts, num_tokens], indices
  • +
  • scale: [num_tokens, hs], rowwise_scale_inv, optional
  • +
  • num_tokens: [n_experts]
  • +
  • num_experts:
  • +
  • num_out_tokens:
  • +
  • hidden_size:
  • +
  • smooth_scales: [n_experts, hidden_size]
  • +
  • reverse:
  • +
  • round_scale:
  • +
+ +

Returns:

+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/loss.html b/docs/linghe/utils/loss.html new file mode 100644 index 0000000..73e62df --- /dev/null +++ b/docs/linghe/utils/loss.html @@ -0,0 +1,121 @@ + + + + + + + linghe.utils.loss API documentation + + + + + + + + + +
+
+

+linghe.utils.loss

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_softmax_cross_entropy_forward(logits, labels): + + +
+ + +

compute token-wise softmax cross entropy loss

+ +
Arguments:
+ +
    +
  • logits: logits tensor
  • +
  • labels: labels tensor
  • +
+ +
Returns:
+ +
+

loss of each token

+
+
+ + +
+
+
+ + def + triton_softmax_cross_entropy_backward(logits, labels, sum_exp, max_logit, input_grad, output_grad=None): + + +
+ + +

backward of softmax cross entropy loss

+ +
Arguments:
+ +
    +
  • logits: logit tensor, [bs, dim]
  • +
  • labels: label tensor, [bs]
  • +
  • sum_exp: [bs]
  • +
  • max_logit: [bs]
  • +
  • input_grad: gradient, [bs, dim]
  • +
+ +
Returns:
+ +
+

output_grad: [bs, dim]

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/norm.html b/docs/linghe/utils/norm.html new file mode 100644 index 0000000..63e9fb9 --- /dev/null +++ b/docs/linghe/utils/norm.html @@ -0,0 +1,164 @@ + + + + + + + linghe.utils.norm API documentation + + + + + + + + + +
+
+

+linghe.utils.norm

+ + + + + +
+
+
+ + def + triton_rms_norm_forward(x, weight, eps=1e-06, out=None): + + +
+ + +

rms norm

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • weight: weight of rms norm
  • +
  • eps: epsilon of rms norm
  • +
+ +
Returns:
+ +
+

out: output tensor

+
+
+ + +
+
+
+ + def + triton_rms_norm_and_block_quant_forward( x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06, out: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, rms: Optional[torch.Tensor] = None, round_scale: bool = False, output_mode: int = 2): + + +
+ + +

Fused RMSNorm forward and block quantization.

+ +
Arguments:
+ +
    +
  • x: Input tensor, shape [M, N]
  • +
  • weight: RMSNorm weight, shape [N]
  • +
  • eps: epsilon value for L2 normalization.
  • +
  • out: output of quantization data
  • +
  • scale: output of quantization scale.
  • +
  • rms: output of rms
  • +
  • round_scale: Set whether to force power of 2 scales.
  • +
  • output_mode: one of {0, 1, 2}. +0: only output non-transpose tensor +1: only output transposed tensor +2: return both
  • +
+ +
Returns:
+ +
+

out: quantization data + scale: quantization scale + rms: Reciprocal of the root mean square of the input calculated over the last dimension. + transpose_output: quantization data of transposed gradient + transpose_scale: quantization scale of transposed gradient

+
+
+ + +
+
+
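A sketch (the five-value return unpacking assumes the order given in the Returns list above, with output_mode=2):

```python
import torch
from linghe.utils.norm import triton_rms_norm_and_block_quant_forward

x = torch.randn(8192, 2048, device="cuda", dtype=torch.bfloat16)
w = torch.ones(2048, device="cuda", dtype=torch.bfloat16)

# output_mode=2 returns both layouts, so the forward GEMM and the
# weight-gradient GEMM each get the orientation they need without
# a separate transpose pass
out, scale, rms, t_out, t_scale = triton_rms_norm_and_block_quant_forward(
    x, w, eps=1e-6, round_scale=True, output_mode=2
)
```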
+ + def + triton_group_norm_gate_forward(x: torch.Tensor, gate, weight, eps=1e-06, group_size=4): + + +
+ + +

norm and gate in linear attention

+ +
Arguments:
+ +
    +
  • x: output of attn, [bs, length, n_heads, head_dim]
  • +
  • gate: gate tensor, [length, bs, dim]
  • +
  • weight: rms norm weight, [dim]
  • +
  • eps: epsilon of rms norm
  • +
  • group_size: group size of group rms norm
  • +
+ +
Returns:
+ +
+

output tensor

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/rearange.html b/docs/linghe/utils/rearange.html new file mode 100644 index 0000000..dab027f --- /dev/null +++ b/docs/linghe/utils/rearange.html @@ -0,0 +1,91 @@ + + + + + + + linghe.utils.rearange API documentation + + + + + + + + + +
+
+

+linghe.utils.rearange

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_split_and_cat(x, counts, indices, scales=None): + + +
+ + +

split x into multiple tensors and cat them with indices, +it is used for permutation in MoE

+ +
Arguments:
+ +
    +
  • x: [bs, dim]
  • +
  • counts: [n_split]
  • +
  • indices: [n_split]
  • +
  • scales: [bs]
  • +
+ +
Returns:
+ +
+

y: output tensor + output_scales: output scales if scales is not None

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/reduce.html b/docs/linghe/utils/reduce.html new file mode 100644 index 0000000..d000c2d --- /dev/null +++ b/docs/linghe/utils/reduce.html @@ -0,0 +1,151 @@ + + + + + + + linghe.utils.reduce API documentation + + + + + + + + + +
+
+

+linghe.utils.reduce

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_abs_max(x, scale=None, smooth_scale=None, min_value=1e-30, axis=0): + + +
+ + +

columnwise abs max of x, it is used in smooth quantization

+ +
Arguments:
+ +
    +
  • x: input tensor, may be quantized tensor
  • +
  • scale: quantization scale if x is quantized
  • +
  • smooth_scale: optional smooth scale
  • +
  • min_value: output = max(max(abs(x), 0), min_value)
  • +
  • axis: reduce axis
  • +
+ +
Returns:
+ +
+

max tensor

+
+
+ + +
+
+
+ + def + triton_batch_count_zero(xs): + + +
+ + +

count zeros in a tensor list, it is used to monitor zeros in gradient tensors

+ +
Arguments:
+ +
    +
  • xs: input tensors
  • +
+ +
Returns:
+ +
+

a single-value int64 tensor

+
+
+ + +
+
+
+ + def + triton_batch_sum_with_ord(xs, ord=2): + + +
+ + +

return sum(abs(x)**ord).

+ +
Arguments:
+ +
    +
  • xs: Tensor lists.
  • +
  • ord: order of the norm, 1 for abs sum and 2 for square sum.
  • +
+ +
Returns:
+ +
+

a single-value fp32 tensor

+
+
+ + +
+
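A monitoring sketch over per-expert gradients (shapes illustrative); both reductions run as one batched launch over the whole list instead of a Python loop:

```python
import torch
from linghe.utils.reduce import triton_batch_count_zero, triton_batch_sum_with_ord

grads = [torch.randn(2048, 2048, device="cuda") for _ in range(8)]

sq_sum = triton_batch_sum_with_ord(grads, ord=2)  # sum of squares, fp32 scalar tensor
grad_norm = sq_sum.sqrt()
zeros = triton_batch_count_zero(grads)            # int64 scalar tensor
```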
+ + \ No newline at end of file diff --git a/docs/linghe/utils/rope.html b/docs/linghe/utils/rope.html new file mode 100644 index 0000000..df394a3 --- /dev/null +++ b/docs/linghe/utils/rope.html @@ -0,0 +1,172 @@ + + + + + + + linghe.utils.rope API documentation + + + + + + + + + +
+
+

+linghe.utils.rope

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_half_rope_forward(q, k, freqs): + + +
+ + +

apply half rope to q and k

+ +
Arguments:
+ +
    +
  • q: query tensor, [len, bs, q_head, head_dim]
  • +
  • k: key tensor, [len, bs, kv_head, head_dim]
  • +
  • freqs: rope freqs
  • +
+ +
Returns:
+ +
+

qo: + ko:

+
+
+ + +
+
+
+ + def + triton_qk_norm_and_half_rope_forward( qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4, eps=1e-06, interleave=True, transpose=False): + + +
+ + +

split qkv to q/k/v, apply qk norm and half rope to q/k, + transpose q/k/v to flash-attention layout

+ +
Arguments:
+ +
    +
  • qkv: QKV tensor with size of [S, B, dim], heads are interleaved
  • +
  • q_norm_weight: rms norm weight for query
  • +
  • k_norm_weight: rms norm weight for key
  • +
  • freqs: Freqs tensor based on half dim.
  • +
  • H: Number of attention heads.
  • +
  • h: Number of key/value heads.
  • +
  • eps: epsilon value for L2 normalization.
  • +
  • interleave: whether head of qkv is interleaved, i.e., [qqkvqqkv]
  • +
  • transpose: whether qkv is transposed, i.e., [S, B, dim], +only supports the transposed format currently
  • +
+ +
Returns:
+ +
+

qo: shape [B, S, H, head_dim] + ko: shape [B, S, h, head_dim] + vo: shape [B, S, h, head_dim]

+
+
+ + +
+
+
+ + def + triton_qk_norm_and_half_rope_backward( gq, gk, gv, qkv, q_norm_weight, k_norm_weight, freqs, eps=1e-06, transpose=False, interleave=True): + + +
+ + +

backward kernel of triton_qk_norm_and_half_rope_forward

+ +
Arguments:
+ +
    +
  • gq: gradient of qo, [len, bs, q_head, head_dim]
  • +
  • gk: gradient of ko, [len, bs, kv_head, head_dim]
  • +
  • gv: gradient of vo, [len, bs, kv_head, head_dim]
  • +
  • qkv: input qkv
  • +
  • q_norm_weight:
  • +
  • k_norm_weight:
  • +
  • freqs:
  • +
  • eps:
  • +
  • transpose:
  • +
  • interleave:
  • +
+ +
Returns:
+ +
+

dqkv: gradient of qkv + dqw: gradient of q_norm_weight + dkw: gradient of k_norm_weight

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/scatter.html b/docs/linghe/utils/scatter.html new file mode 100644 index 0000000..68e5f39 --- /dev/null +++ b/docs/linghe/utils/scatter.html @@ -0,0 +1,154 @@ + + + + + + + linghe.utils.scatter API documentation + + + + + + + + + +
+
+

+linghe.utils.scatter

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_aligned_scatter_add( x: torch.Tensor, outputs: torch.Tensor, indices: torch.Tensor, weights: Optional[torch.Tensor] = None): + + +
+ + +

scatter_add for megatron 0.11

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • outputs: output tensor
  • +
  • indices: gather indices
  • +
  • weights: rowwise weight, it is router prob in MoE router
  • +
+ +
Returns:
+ +
+

output tensor

+
+
+ + +
+
+
+ + def + triton_scatter_add(x, outputs, indices): + + +
+ + +

naive version of scatter add, very slow

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • outputs: output tensor
  • +
  • indices: indices
  • +
+ +
Returns:
+ +
+

outputs

+
+
+ + +
+
+
+ + def + triton_unpermute_with_mask_map(grad: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor): + + +
+ + +

scatter add with row id map

+ +
Arguments:
+ +
    +
  • grad: gradient tensor, [num_out_tokens, hidden_size]
  • +
  • row_id_map: row id map, [n_experts, num_tokens]
  • +
  • probs: [num_out_tokens]
  • +
+ +
Returns:
+ +
+

output: [num_tokens, hidden_size] + restore_probs: [num_tokens, num_experts]

+
+
+ + +
+
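A sketch of the combine direction (the row id map layout caveat from the gather page applies here too; shapes are illustrative):

```python
import torch
from linghe.utils.gather import triton_make_row_id_map
from linghe.utils.scatter import triton_unpermute_with_mask_map

num_tokens, hidden, n_experts, topk = 1024, 512, 4, 2

scores = torch.rand(num_tokens, n_experts, device="cuda")
routing_map = torch.zeros_like(scores, dtype=torch.int32)
routing_map.scatter_(1, scores.topk(topk, dim=1).indices, 1)
row_id_map = triton_make_row_id_map(routing_map)  # no padding, indices stay contiguous

num_out_tokens = num_tokens * topk
grad = torch.randn(num_out_tokens, hidden, device="cuda", dtype=torch.bfloat16)
probs = torch.rand(num_out_tokens, device="cuda")

# scatter-add each permuted row back onto its source token
output, restore_probs = triton_unpermute_with_mask_map(grad, row_id_map, probs)
# output: [num_tokens, hidden]; restore_probs: [num_tokens, n_experts]
```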
+ + \ No newline at end of file diff --git a/docs/linghe/utils/silu.html b/docs/linghe/utils/silu.html new file mode 100644 index 0000000..cbb5131 --- /dev/null +++ b/docs/linghe/utils/silu.html @@ -0,0 +1,273 @@ + + + + + + + linghe.utils.silu API documentation + + + + + + + + + +
+
+

+linghe.utils.silu

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_weighted_silu_forward(x, weight=None, out=None): + + +
+ + +

compute silu(x)*weight, used in bf16/fp16 training with MoE

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • weight: tokenwise weight
  • +
+ +
Returns:
+ +
+

out: output tensor

+
+
+ + +
+
+
+ + def + triton_weighted_silu_backward( g: torch.Tensor, x: torch.Tensor, weight: Optional[torch.Tensor] = None): + + +
+ + +

backward of triton_weighted_silu_forward

+ +
Arguments:
+ +
    +
  • g: gradient tensor
  • +
  • x: input tensor
  • +
  • weight: weight tensor
  • +
+ +
Returns:
+ +
+

dx: gradient of x + dw: gradient of weight

+
+
+ + +
+
+
+ + def + triton_silu_and_block_quant_forward(x, out=None, scale=None, round_scale=False, output_mode=2): + + +
+ + +

fused silu and blockwise quantization, used in shared expert

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • round_scale: whether round scale to power of 2
  • +
  • output_mode: one of {0, 1, 2} +0: only output non-transposed quantized tensor +1: only output transposed quantized tensor +2: output both
  • +
+ +
Returns:
+ +
+

out: quantized tensor + scale: quantization scale + transpose_output: quantized tensor of transposed output + transpose_scale: quantization scale of transposed output

+
+
+ + +
+
+
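A sketch for the shared-expert path (the four-value return unpacking assumes the Returns list above, with output_mode=2; shapes are illustrative):

```python
import torch
from linghe.utils.silu import triton_silu_and_block_quant_forward

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

out, scale, t_out, t_scale = triton_silu_and_block_quant_forward(
    x, round_scale=True, output_mode=2
)
# out / t_out: fp8 activations in both layouts for the two downstream GEMMs
```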
+ + def + triton_silu_and_block_quant_backward(g, x, round_scale=False): + + +
+ + +

backward of triton_silu_and_block_quant_forward

+ +
Arguments:
+ +
    +
  • g: gradient
  • +
  • x: input tensor
  • +
  • round_scale: whether round to power of 2
  • +
+ +
Returns:
+ +
+

dx: quantized non-transposed gradient + dx_scale: scales of quantization non-transposed gradient + transpose_dx: quantized transposed gradient + transpose_dx_scale: scales of quantization transposed gradient

+
+
+ + +
+
+
+ + def + triton_batch_weighted_silu_and_block_quant_forward( x, weight, counts, splits=None, out=None, scale=None, round_scale=False, output_mode=2): + + +
+ + +

silu and blockwise quantize activation in routed experts

+ +
Arguments:
+ +
    +
  • x: activation tensor in routed experts
  • +
  • weight: router prob tensor
  • +
  • counts: cuda tensor of token count per expert
  • +
  • splits: python int list of token count per expert
  • +
  • round_scale: whether round scale to power of 2
  • +
  • output_mode: one of {0, 1, 2} +0: only output non-transposed quantized tensor +1: only output transposed quantized tensor +2: output both
  • +
+ +
Returns:
+ +
+

out: quantized tensor + scale: quantization scale + transpose_output: quantized tensor of transposed output + transpose_scale: quantization scale of transposed output

+
+
+ + +
+
+
+ + def + triton_batch_weighted_silu_and_block_quant_backward(g, x, weight, counts, splits=None, round_scale=False): + + +
+ + +

backward of triton_batch_weighted_silu_and_block_quant_forward

+ +
Arguments:
+ +
    +
  • g: gradient
  • +
  • x: input tensor
  • +
  • weight: router prob tensor
  • +
  • counts: cuda tensor of token count per expert
  • +
  • splits: python int list of token count per expert
  • +
  • round_scale: whether round scale to power of 2
  • +
+ +
Returns:
+ +
+

dx: quantized non-transposed gradient + dx_scale: scales of quantization non-transposed gradient + dw: gradient of weight + transpose_dx: quantized transposed gradient + transpose_dx_scale: scales of quantization transposed gradient

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/linghe/utils/transpose.html b/docs/linghe/utils/transpose.html new file mode 100644 index 0000000..b278ac2 --- /dev/null +++ b/docs/linghe/utils/transpose.html @@ -0,0 +1,184 @@ + + + + + + + linghe.utils.transpose API documentation + + + + + + + + + +
+
+

+linghe.utils.transpose

+ +

Copyright (c) Ant Financial Service Group and its affiliates.

+
+ + + + +
+
+
+ + def + triton_transpose( x: torch.Tensor, dim0: Optional[int] = None, dim1: Optional[int] = None): + + +
+ + +

transpose x with dim0 and dim1

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • dim0: dim 0
  • +
  • dim1: dim 1
  • +
+ +
Returns:
+ +
+

transposed tensor

+
+
+ + +
+
+
+ + def + triton_transpose_and_pad(x, out=None, pad=True): + + +
+ + +

transpose x and pad the column size to a multiple of 32, +it is used for calculating the gradient of weight with torch._scaled_mm

+ +
Arguments:
+ +
    +
  • x: input tensor
  • +
  • out:
  • +
  • pad: whether need padding
  • +
+ +
Returns:
+ +
+

out: output tensor

+
+
+ + +
+
+
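A sketch (the fp8 input dtype is an assumption based on the torch._scaled_mm use case above; shapes are illustrative):

```python
import torch
from linghe.utils.transpose import triton_transpose_and_pad

x = torch.randn(4000, 2048, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# transposed copy with the new column count padded 4000 -> 4032 (multiple of 32),
# ready for the weight-gradient GEMM via torch._scaled_mm
x_t = triton_transpose_and_pad(x, pad=True)
```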
def triton_batch_transpose(xs, xts=None):

batch transpose each tensor in xs

Arguments:

  • xs: list of input tensors, [M, N] * num_experts
  • xts: optional list of preallocated output tensors

Returns:

  xts: list of output tensors, [N, M] * num_experts
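Sketch, assuming each list element is an independent [M, N] CUDA tensor:

    xs = [torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16) for _ in range(4)]
    xts = triton_batch_transpose(xs)
    assert all(t.shape == (2048, 1024) for t in xts)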
def triton_batch_transpose_and_pad(x, count_list, x_t=None, pad=True):

transpose and pad each per-expert tensor stored in x

Arguments:

  • x: input tensor, [sum(bs), N]
  • count_list: a Python list of token counts per expert
  • x_t: optional preallocated output tensor
  • pad: whether to pad to a multiple of 32;
    padded values are filled with 0

Returns:

  x_t: output tensor
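Sketch for the fused per-expert case: all expert activations live in one stacked tensor and count_list marks the row boundaries (counts below are deliberately not 32-aligned; shapes are illustrative):

    splits = [1000, 2048, 500, 4096]
    x = torch.randn(sum(splits), 2048, device="cuda", dtype=torch.bfloat16)
    x_t = triton_batch_transpose_and_pad(x, splits, pad=True)
    # each expert's [count, N] slice is transposed to [N, padded_count],
    # with the padded tail filled with zeros as noted above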
\ No newline at end of file
diff --git a/docs/search.js b/docs/search.js
new file mode 100644
index 0000000..23741f0
--- /dev/null
+++ b/docs/search.js
@@ -0,0 +1,46 @@
+window.pdocSearch = (function(){
+/** elasticlunr 0.9.5 - http://weixsong.github.io * Copyright (C) 2017 Oliver Nightingale * Copyright (C) 2017 Wei Song * MIT Licensed */
[vendored minified elasticlunr library body omitted; the generated search-index entries follow]
\n"}, {"fullname": "linghe.facade", "modulename": "linghe.facade", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.facade.add", "modulename": "linghe.facade.add", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.facade.add.InplaceAddFunction", "modulename": "linghe.facade.add", "qualname": "InplaceAddFunction", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.add.InplaceAddFunction.forward", "modulename": "linghe.facade.add", "qualname": "InplaceAddFunction.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, x, y):", "funcdef": "def"}, {"fullname": "linghe.facade.add.InplaceAddFunction.backward", "modulename": "linghe.facade.add", "qualname": "InplaceAddFunction.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, grad_output):", "funcdef": "def"}, {"fullname": "linghe.facade.fp32_linear", "modulename": "linghe.facade.fp32_linear", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.facade.fp32_linear.FusedFp32GEMM", "modulename": "linghe.facade.fp32_linear", "qualname": "FusedFp32GEMM", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.fp32_linear.FusedFp32GEMM.forward", "modulename": "linghe.facade.fp32_linear", "qualname": "FusedFp32GEMM.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, input, weight):", "funcdef": "def"}, {"fullname": "linghe.facade.fp32_linear.FusedFp32GEMM.backward", "modulename": "linghe.facade.fp32_linear", "qualname": "FusedFp32GEMM.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, grad_output):", "funcdef": "def"}, {"fullname": "linghe.facade.loss", "modulename": "linghe.facade.loss", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.facade.loss.SoftmaxCrossEntropyFunction", "modulename": "linghe.facade.loss", "qualname": "SoftmaxCrossEntropyFunction", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.loss.SoftmaxCrossEntropyFunction.forward", "modulename": "linghe.facade.loss", "qualname": "SoftmaxCrossEntropyFunction.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, logits, labels, inplace=False):", "funcdef": "def"}, {"fullname": "linghe.facade.loss.SoftmaxCrossEntropyFunction.backward", "modulename": "linghe.facade.loss", "qualname": "SoftmaxCrossEntropyFunction.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, grad_output):", "funcdef": "def"}, {"fullname": "linghe.facade.loss.GradScalingFunction", "modulename": "linghe.facade.loss", "qualname": "GradScalingFunction", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.loss.GradScalingFunction.forward", "modulename": "linghe.facade.loss", "qualname": "GradScalingFunction.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, x, coef=0.2):", "funcdef": "def"}, {"fullname": "linghe.facade.loss.GradScalingFunction.backward", "modulename": "linghe.facade.loss", "qualname": "GradScalingFunction.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, grad_output):", "funcdef": "def"}, {"fullname": "linghe.facade.norm", "modulename": "linghe.facade.norm", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.facade.norm.RMSNormFunction", "modulename": "linghe.facade.norm", "qualname": "RMSNormFunction", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.norm.RMSNormFunction.forward", "modulename": "linghe.facade.norm", "qualname": "RMSNormFunction.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, x, weight, eps=1e-06):", "funcdef": "def"}, {"fullname": "linghe.facade.norm.RMSNormFunction.backward", "modulename": "linghe.facade.norm", "qualname": "RMSNormFunction.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, dy):", "funcdef": "def"}, {"fullname": "linghe.facade.norm.GroupNormGateFunction", "modulename": "linghe.facade.norm", "qualname": "GroupNormGateFunction", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.norm.GroupNormGateFunction.forward", "modulename": "linghe.facade.norm", "qualname": "GroupNormGateFunction.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, x, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.facade.norm.GroupNormGateFunction.backward", "modulename": "linghe.facade.norm", "qualname": "GroupNormGateFunction.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, dy):", "funcdef": "def"}, {"fullname": "linghe.facade.rope", "modulename": "linghe.facade.rope", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.facade.rope.QkNormHalfRopeFunction", "modulename": "linghe.facade.rope", "qualname": "QkNormHalfRopeFunction", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.rope.QkNormHalfRopeFunction.forward", "modulename": "linghe.facade.rope", "qualname": "QkNormHalfRopeFunction.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4, eps=1e-06):", "funcdef": "def"}, {"fullname": "linghe.facade.rope.QkNormHalfRopeFunction.backward", "modulename": "linghe.facade.rope", "qualname": "QkNormHalfRopeFunction.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, grad_q, grad_k, grad_v):", "funcdef": "def"}, {"fullname": "linghe.facade.transpose", "modulename": "linghe.facade.transpose", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.facade.transpose.TransposeDim01Function", "modulename": "linghe.facade.transpose", "qualname": "TransposeDim01Function", "kind": "class", "doc": "

Base class to create custom autograd.Function.

\n\n

To create a custom autograd.Function, subclass this class and implement\nthe :meth:forward and :meth:backward static methods. Then, to use your custom\nop in the forward pass, call the class method apply. Do not call\n:meth:forward directly.

\n\n

To ensure correctness and best performance, make sure you are calling the\ncorrect methods on ctx and validating your backward function using\n:func:torch.autograd.gradcheck.

\n\n

See :ref:extending-autograd for more details on how to use this class.

\n\n

Examples::

\n\n
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)\n>>> class Exp(Function):\n>>>     @staticmethod\n>>>     def forward(ctx, i):\n>>>         result = i.exp()\n>>>         ctx.save_for_backward(result)\n>>>         return result\n>>>\n>>>     @staticmethod\n>>>     def backward(ctx, grad_output):\n>>>         result, = ctx.saved_tensors\n>>>         return grad_output * result\n>>>\n>>> # Use it by calling the apply method:\n>>> # xdoctest: +SKIP\n>>> output = Exp.apply(input)\n
\n", "bases": "torch.autograd.function.Function"}, {"fullname": "linghe.facade.transpose.TransposeDim01Function.forward", "modulename": "linghe.facade.transpose", "qualname": "TransposeDim01Function.forward", "kind": "function", "doc": "

Define the forward of the custom autograd Function.

\n\n

This function is to be overridden by all subclasses.\nThere are two ways to define forward:

\n\n

Usage 1 (Combined forward and ctx)::

\n\n
@staticmethod\ndef forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:\n    pass\n
\n\n
    \n
  • It must accept a context ctx as the first argument, followed by any\nnumber of arguments (tensors or other types).
  • \n
  • See :ref:combining-forward-context for more details
  • \n
\n\n

Usage 2 (Separate forward and ctx)::

\n\n
@staticmethod\ndef forward(*args: Any, **kwargs: Any) -> Any:\n    pass\n\n@staticmethod\ndef setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:\n    pass\n
\n\n
    \n
  • The forward no longer accepts a ctx argument.
  • \n
  • Instead, you must also override the :meth:torch.autograd.Function.setup_context\nstaticmethod to handle setting up the ctx object.\noutput is the output of the forward, inputs are a Tuple of inputs\nto the forward.
  • \n
  • See :ref:extending-autograd for more details
  • \n
\n\n

The context can be used to store arbitrary data that can be then\nretrieved during the backward pass. Tensors should not be stored\ndirectly on ctx (though this is not currently enforced for\nbackward compatibility). Instead, tensors should be saved either with\n:func:ctx.save_for_backward if they are intended to be used in\nbackward (equivalently, vjp) or :func:ctx.save_for_forward\nif they are intended to be used for in jvp.

\n", "signature": "(ctx, x):", "funcdef": "def"}, {"fullname": "linghe.facade.transpose.TransposeDim01Function.backward", "modulename": "linghe.facade.transpose", "qualname": "TransposeDim01Function.backward", "kind": "function", "doc": "

Define a formula for differentiating the operation with backward mode automatic differentiation.

\n\n

This function is to be overridden by all subclasses.\n(Defining this function is equivalent to defining the vjp function.)

\n\n

It must accept a context :attr:ctx as the first argument, followed by\nas many outputs as the :func:forward returned (None will be passed in\nfor non tensor outputs of the forward function),\nand it should return as many tensors, as there were inputs to\n:func:forward. Each argument is the gradient w.r.t the given output,\nand each returned value should be the gradient w.r.t. the\ncorresponding input. If an input is not a Tensor or is a Tensor not\nrequiring grads, you can just pass None as a gradient for that input.

\n\n

The context can be used to retrieve tensors saved during the forward\npass. It also has an attribute :attr:ctx.needs_input_grad as a tuple\nof booleans representing whether each input needs gradient. E.g.,\n:func:backward will have ctx.needs_input_grad[0] = True if the\nfirst input to :func:forward needs gradient computed w.r.t. the\noutput.

\n", "signature": "(ctx, grad_output):", "funcdef": "def"}, {"fullname": "linghe.gemm", "modulename": "linghe.gemm", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.gemm.fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.gemm.fp32_gemm.fp32_gemm_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "fp32_gemm_kernel", "kind": "function", "doc": "

\n", "signature": "(\ta_ptr,\tb_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm", "kind": "function", "doc": "

\n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.scaled_fp32_gemm_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "scaled_fp32_gemm_kernel", "kind": "function", "doc": "

\n", "signature": "(\ta_ptr,\tb_ptr,\tscale_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_scaled_fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_scaled_fp32_gemm", "kind": "function", "doc": "

\n", "signature": "(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.fp32_gemm_for_backward_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "fp32_gemm_for_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\ta_ptr,\tb_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tACCUM: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm_for_backward", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm_for_backward", "kind": "function", "doc": "

\n", "signature": "(\ta: torch.Tensor,\tb: torch.Tensor,\tc: Optional[torch.Tensor] = None,\taccum=False):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.fp32_gemm_for_update_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "fp32_gemm_for_update_kernel", "kind": "function", "doc": "

\n", "signature": "(\ta_ptr,\tb_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm_for_update", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm_for_update", "kind": "function", "doc": "

\n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.scaled_fp32_gemm_for_update_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "scaled_fp32_gemm_for_update_kernel", "kind": "function", "doc": "

\n", "signature": "(\ta_ptr,\tb_ptr,\tscale_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_scaled_fp32_gemm_for_update", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_scaled_fp32_gemm_for_update", "kind": "function", "doc": "

\n", "signature": "(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.quant", "modulename": "linghe.quant", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.quant.block", "modulename": "linghe.quant.block", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.quant.block.block", "modulename": "linghe.quant.block.block", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.quant.block.block.block_quant_kernel", "modulename": "linghe.quant.block.block", "qualname": "block_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.block.block.block_quant", "modulename": "linghe.quant.block.block", "qualname": "block_quant", "kind": "function", "doc": "

\n", "signature": "(x, block_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group", "modulename": "linghe.quant.block.group", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.quant.block.group.group_quant_kernel", "modulename": "linghe.quant.block.group", "qualname": "group_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: int, K: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group.triton_group_quant", "modulename": "linghe.quant.block.group", "qualname": "triton_group_quant", "kind": "function", "doc": "

\n", "signature": "(x, dtype=torch.float8_e4m3fn, group_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group.persist_group_quant_kernel", "modulename": "linghe.quant.block.group", "qualname": "persist_group_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: int, B: int, K: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group.triton_persist_group_quant", "modulename": "linghe.quant.block.group", "qualname": "triton_persist_group_quant", "kind": "function", "doc": "

\n", "signature": "(x, dtype=torch.float8_e4m3fn, group_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel", "modulename": "linghe.quant.channel", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.quant.channel.channel", "modulename": "linghe.quant.channel.channel", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.quant.channel.channel.row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "row_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, q_ptr, s_ptr, M, N, BLOCK_SIZE: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_row_quant", "kind": "function", "doc": "

\n", "signature": "(x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.deprecated_tokenwise_row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "deprecated_tokenwise_row_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, out_ptr, scale_ptr, M, T: int, N: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_deprecated_tokenwise_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_deprecated_tokenwise_row_quant", "kind": "function", "doc": "

\n", "signature": "(x, out=None, scale=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.tokenwise_row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "tokenwise_row_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, out_ptr, scale_ptr, N: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_tokenwise_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_tokenwise_row_quant", "kind": "function", "doc": "

\n", "signature": "(x, out=None, scale=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.transpose_row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "transpose_row_quant_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, q_ptr, s_ptr, M, N, H: int, W: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_transpose_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_transpose_row_quant", "kind": "function", "doc": "

\n", "signature": "(x, side=0, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_channel_quant_nt", "modulename": "linghe.quant.channel.channel", "qualname": "triton_channel_quant_nt", "kind": "function", "doc": "

\n", "signature": "(x, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_channel_quant_nn", "modulename": "linghe.quant.channel.channel", "qualname": "triton_channel_quant_nn", "kind": "function", "doc": "

\n", "signature": "(y, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_channel_quant_tn", "modulename": "linghe.quant.channel.channel", "qualname": "triton_channel_quant_tn", "kind": "function", "doc": "

\n", "signature": "(y, x):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.channel_quant_forward", "modulename": "linghe.quant.channel.channel", "qualname": "channel_quant_forward", "kind": "function", "doc": "

\n", "signature": "(x, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.channel_quant_backward", "modulename": "linghe.quant.channel.channel", "qualname": "channel_quant_backward", "kind": "function", "doc": "

\n", "signature": "(y, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.channel_quant_update", "modulename": "linghe.quant.channel.channel", "qualname": "channel_quant_update", "kind": "function", "doc": "

\n", "signature": "(y, x):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.fp8_channel_f_and_b", "modulename": "linghe.quant.channel.channel", "qualname": "fp8_channel_f_and_b", "kind": "function", "doc": "

\n", "signature": "(x, w, y):", "funcdef": "def"}, {"fullname": "linghe.utils", "modulename": "linghe.utils", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.utils.add", "modulename": "linghe.utils.add", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.add.inplace_add_kernel", "modulename": "linghe.utils.add", "qualname": "inplace_add_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, y_ptr, M, N, H: int, W: int, EVEN: int, ACCUM: int):", "funcdef": "def"}, {"fullname": "linghe.utils.add.triton_inplace_add", "modulename": "linghe.utils.add", "qualname": "triton_inplace_add", "kind": "function", "doc": "

inplace add y to x\nArgs:\n x: Tensor\n y: Tensor\n accum: whether accum y to x

\n\n

Returns: x += y if accum=True else x.copy_(y)

\n", "signature": "(x: torch.Tensor, y: torch.Tensor, accum: bool = True):", "funcdef": "def"}, {"fullname": "linghe.utils.dot", "modulename": "linghe.utils.dot", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.dot.dot_kernel", "modulename": "linghe.utils.dot", "qualname": "dot_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, y_ptr, sum_ptr, M, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.triton_dot", "modulename": "linghe.utils.dot", "qualname": "triton_dot", "kind": "function", "doc": "

\n", "signature": "(x, y):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.mix_precise_dot_kernel", "modulename": "linghe.utils.dot", "qualname": "mix_precise_dot_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tq_ptr,\tsum_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tN,\tH: int,\tW: int):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.triton_mix_precise_dot", "modulename": "linghe.utils.dot", "qualname": "triton_mix_precise_dot", "kind": "function", "doc": "

\n", "signature": "(x, q, smooth_scale, quant_scale, reverse=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather", "modulename": "linghe.utils.gather", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.gather.block_count_kernel", "modulename": "linghe.utils.gather", "qualname": "block_count_kernel", "kind": "function", "doc": "

\n", "signature": "(map_ptr, count_ptr, M, B, T: int, b: int, E: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.make_row_id_map_kernel", "modulename": "linghe.utils.gather", "qualname": "make_row_id_map_kernel", "kind": "function", "doc": "

\n", "signature": "(map_ptr, count_ptr, output_ptr, M, B, P, T: int, b: int, E: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map", "kind": "function", "doc": "

\n", "signature": "(routing_map: torch.Tensor, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.make_row_id_map_and_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "make_row_id_map_and_indices_kernel", "kind": "function", "doc": "

\n", "signature": "(\tmap_ptr,\tcount_ptr,\trow_map_ptr,\trow_indices_ptr,\tM,\tB,\tP,\tT: int,\tb: int,\tE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map_and_indices", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map_and_indices", "kind": "function", "doc": "

\n", "signature": "(routing_map: torch.Tensor, num_out_tokens: int, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.index_select_kernel", "modulename": "linghe.utils.gather", "qualname": "index_select_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tout_ptr,\tscale_ptr,\tscale_out_ptr,\tindex_ptr,\tM,\tT,\tN: int,\tSCALE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_index_select", "modulename": "linghe.utils.gather", "qualname": "triton_index_select", "kind": "function", "doc": "

\n", "signature": "(x, indices, scale=None, out=None, scale_out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "permute_with_mask_map_kernel", "kind": "function", "doc": "

\n", "signature": "(\tdata_ptr,\tscale_ptr,\tprobs_ptr,\tmask_map_ptr,\toutput_data_ptr,\toutput_scale_ptr,\toutput_probs_ptr,\tnum_experts: int,\tN: int,\ths: int,\tSCALE: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.fill_padded_token_with_zero_kernel", "modulename": "linghe.utils.gather", "qualname": "fill_padded_token_with_zero_kernel", "kind": "function", "doc": "

\n", "signature": "(\tdata_ptr,\tscale_ptr,\tprobs_ptr,\tmax_indices_ptr,\ttoken_per_expert_ptr,\tN: int,\ths: int,\tSCALE: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_permute_with_mask_map", "kind": "function", "doc": "

\n", "signature": "(\tinp: torch.Tensor,\tscale: torch.Tensor,\tprobs: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_out_tokens: int,\tcontiguous: bool = True,\ttokens_per_expert: Optional[torch.Tensor] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.batch_smooth_transpose_smooth_permute_kernel", "modulename": "linghe.utils.gather", "qualname": "batch_smooth_transpose_smooth_permute_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tscale_ptr,\toss_ptr,\tss_ptr,\tindex_ptr,\tcount_ptr,\taccum_ptr,\tq_ptr,\tqs_ptr,\tN: int,\tE: int,\tH: int,\tW: int,\tSMOOTHED: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_batch_transpose_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_batch_transpose_smooth_permute_with_indices", "kind": "function", "doc": "

\n", "signature": "(\tx,\tscale,\torg_smooth_scale,\tsmooth_scales,\tindices,\ttoken_count_per_expert,\tsplits,\tx_q=None,\tx_scale=None,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_weighted_permute_with_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_weighted_permute_with_indices_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrads_ptr,\ttokens_ptr,\tq_ptr,\tss_ptr,\tqs_ptr,\tcount_ptr,\taccum_ptr,\tindex_ptr,\tsum_ptr,\tM,\tN: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_weighted_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_weighted_permute_with_indices", "kind": "function", "doc": "

\n", "signature": "(\tgrads,\ttokens,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\tx_sum=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_permute_with_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_permute_with_indices_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrads_data_ptr,\tgrads_scale_ptr,\tq_ptr,\tss_ptr,\tqs_ptr,\tcount_ptr,\taccum_ptr,\tindex_ptr,\tN: int,\ths: int,\tREVERSE: int,\tROUND: int,\tGROUP: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_indices", "kind": "function", "doc": "

\n", "signature": "(\tgrad_data,\tgrad_scale,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_permute_with_mask_map_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrads_data_ptr,\tquant_data_ptr,\tmask_map_ptr,\tgrads_scale_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tT,\tN: int,\ths: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_mask_map", "kind": "function", "doc": "

\n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tscale: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.deprecated_smooth_permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "deprecated_smooth_permute_with_mask_map_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrads_data_ptr,\tquant_data_ptr,\tmask_map_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tT,\tN: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_deprecated_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_deprecated_smooth_permute_with_mask_map", "kind": "function", "doc": "

\n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.loss", "modulename": "linghe.utils.loss", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.loss.softmax_cross_entropy_forward_kernel", "modulename": "linghe.utils.loss", "qualname": "softmax_cross_entropy_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tlogit_ptr,\tlabel_ptr,\tloss_ptr,\tsum_exp_ptr,\tmax_logit_ptr,\tN,\tB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_forward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_forward", "kind": "function", "doc": "

\n", "signature": "(logits, labels):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.softmax_cross_entropy_backward_kernel", "modulename": "linghe.utils.loss", "qualname": "softmax_cross_entropy_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tlogit_ptr,\tlabel_ptr,\tsum_exp_ptr,\tmax_logit_ptr,\tinput_grad_ptr,\toutput_grad_ptr,\tN,\tB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_backward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_backward", "kind": "function", "doc": "

\n", "signature": "(logits, labels, sum_exp, max_logit, input_grad, output_grad=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm", "modulename": "linghe.utils.norm", "kind": "module", "doc": "

\n"}, {"fullname": "linghe.utils.norm.rms_norm_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, weight_ptr, out_ptr, eps, M, T, N: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_forward", "kind": "function", "doc": "

\n", "signature": "(x, weight, eps=1e-06, out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_backward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrad_output_ptr,\tx_ptr,\tw_ptr,\tdx_ptr,\tdw_ptr,\teps,\tM,\tT,\tN: int,\tW: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_backward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_backward", "kind": "function", "doc": "

\n", "signature": "(grad_output, x, w, eps=1e-06):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\trms_ptr,\teps,\tM,\tT: int,\tN: int,\tnb: int,\tW: int,\tH: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_n_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_n_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\trms_ptr,\teps,\tM: int,\tT: int,\tN: int,\tnb: int,\tW: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_t_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_t_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tweight_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\trms_ptr,\tM,\tN,\tW: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_and_block_quant_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_and_block_quant_forward", "kind": "function", "doc": "

Fused RMSNorm forward and block quantization.\nArgs:\n x: Input tensor, shape [M, N]\n weight: RMSNorm weight, shape [N]\n eps: epsilon value for L2 normalization.\n out: output of quantization data\n scale: output of quantization scale.\n rms: output of rms\n round_scale: Set whether to force power of 2 scales.\n output_mode: one of {0, 1, 2}.\n 0: only output non-transpose tensor\n 1: only output transposed tensor\n 2: return both\nReturns:\n out: quantization data\n scale: quantization scale\n rms: Reciprocal of the root mean square of the input calculated over the last dimension.\n transpose_output: quantization data of transposed gradient\n transpose_scale: quantization scale of transposed gradient

\n", "signature": "(\tx: torch.Tensor,\tweight: torch.Tensor,\teps: float = 1e-06,\tout: Optional[torch.Tensor] = None,\tscale: Optional[torch.Tensor] = None,\trms: Optional[torch.Tensor] = None,\tround_scale: bool = False,\toutput_mode: int = 2):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.group_norm_gate_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "group_norm_gate_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tgate_ptr,\tweight_ptr,\tout_ptr,\teps,\tbs,\tlength,\tDIM: int,\tD: int,\tGROUP_SIZE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_forward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_forward", "kind": "function", "doc": "

norm and gate in linear attention\nArgs:\n x:\n gate:\n weight:\n eps:\n group_size:

\n\n

Returns:

\n", "signature": "(x: torch.Tensor, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.group_rms_gate_backward_kernel", "modulename": "linghe.utils.norm", "qualname": "group_rms_gate_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrad_output_ptr,\tx_ptr,\tgate_ptr,\tw_ptr,\tdx_ptr,\tdg_ptr,\tdw_ptr,\teps,\tbs,\tlength,\tDIM: int,\tD: int,\tGROUP_SIZE: int,\tT: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_backward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_backward", "kind": "function", "doc": "

\n", "signature": "(grad_output, x, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange", "modulename": "linghe.utils.rearange", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.rearange.split_and_cat_kernel", "modulename": "linghe.utils.rearange", "qualname": "split_and_cat_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\ty_ptr,\tscale_ptr,\tscale_output_ptr,\tcount_ptr,\taccum_ptr,\trev_accum_ptr,\tindex_ptr,\tM,\tN: int,\tSCALE: int,\tK: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange.triton_split_and_cat", "modulename": "linghe.utils.rearange", "qualname": "triton_split_and_cat", "kind": "function", "doc": "

\n", "signature": "(x, counts, indices, scales=None):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce", "modulename": "linghe.utils.reduce", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.reduce.abs_max_kernel", "modulename": "linghe.utils.reduce", "qualname": "abs_max_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tscale_ptr,\tsmooth_scale_ptr,\toutput_ptr,\tmin_value,\tM,\tN,\tH: int,\tW: int,\tEVEN: int,\tQUANTIZED: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_abs_max", "modulename": "linghe.utils.reduce", "qualname": "triton_abs_max", "kind": "function", "doc": "

\n", "signature": "(x, scale=None, smooth_scale=None, min_value=1e-30, axis=0):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.batch_count_zero_kernel", "modulename": "linghe.utils.reduce", "qualname": "batch_count_zero_kernel", "kind": "function", "doc": "

\n", "signature": "(input_ptrs, size_ptr, count_ptr, B: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_count_zero", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_count_zero", "kind": "function", "doc": "

\n", "signature": "(xs):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.batch_sum_with_ord_kernel", "modulename": "linghe.utils.reduce", "qualname": "batch_sum_with_ord_kernel", "kind": "function", "doc": "

\n", "signature": "(input_ptrs, size_ptr, count_ptr, B: int, ORD: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_sum_with_ord", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_sum_with_ord", "kind": "function", "doc": "

\n", "signature": "(xs, ord=2):", "funcdef": "def"}, {"fullname": "linghe.utils.rope", "modulename": "linghe.utils.rope", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.rope.half_rope_forward_kernel", "modulename": "linghe.utils.rope", "qualname": "half_rope_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tq_ptr,\tk_ptr,\tfreqs_ptr,\tqo_ptr,\tko_ptr,\tB,\tq_stride,\tk_stride,\tH: int,\th: int,\tD: int,\td: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_forward", "kind": "function", "doc": "

\n", "signature": "(q, k, freqs):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.half_rope_backward_kernel", "modulename": "linghe.utils.rope", "qualname": "half_rope_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(q_ptr, k_ptr, freqs_ptr, B, H: int, h: int, D: int, d: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_backward", "kind": "function", "doc": "

\n", "signature": "(q_grad, k_grad, freqs, inplace=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.qk_norm_and_half_rope_forward_kernel", "modulename": "linghe.utils.rope", "qualname": "qk_norm_and_half_rope_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tqkv_ptr,\tq_norm_weight_ptr,\tk_norm_weight_ptr,\tfreqs_ptr,\tqo_ptr,\tko_ptr,\tvo_ptr,\tB,\tstride,\teps,\tH: int,\th: int,\tD: int,\td: int,\tinterleave: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_forward", "kind": "function", "doc": "

\n", "signature": "(\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\tH=32,\th=4,\teps=1e-06,\tinterleave=True,\ttranspose=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.qk_norm_and_half_rope_backward_kernel", "modulename": "linghe.utils.rope", "qualname": "qk_norm_and_half_rope_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgq_ptr,\tgk_ptr,\tgv_ptr,\tqkv_ptr,\tq_norm_weight_ptr,\tk_norm_weight_ptr,\tfreqs_ptr,\tdqkv_ptr,\tdqw_ptr,\tdkw_ptr,\tB,\tstride,\teps,\tH: int,\th: int,\tD: int,\td: int,\tinterleave: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_backward", "kind": "function", "doc": "

\n", "signature": "(\tgq,\tgk,\tgv,\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\teps=1e-06,\ttranspose=False,\tinterleave=True):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter", "modulename": "linghe.utils.scatter", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.scatter.aligned_scatter_add_kernel", "modulename": "linghe.utils.scatter", "qualname": "aligned_scatter_add_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\to_ptr,\tindices_ptr,\tweights_ptr,\tM,\tN: int,\tK: int,\tSCALE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_aligned_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_aligned_scatter_add", "kind": "function", "doc": "

\n", "signature": "(x, outputs, indices, weights=None):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.scatter_add_kernel", "modulename": "linghe.utils.scatter", "qualname": "scatter_add_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, o_ptr, indices_ptr, M, T, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.fp32_to_bf16_kernel", "modulename": "linghe.utils.scatter", "qualname": "fp32_to_bf16_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, o_ptr, M, T, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_scatter_add", "kind": "function", "doc": "

\n", "signature": "(x, outputs, indices):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.unpermute_with_mask_map_kernel", "modulename": "linghe.utils.scatter", "qualname": "unpermute_with_mask_map_kernel", "kind": "function", "doc": "

\n", "signature": "(\tgrads_ptr,\tprobs_ptr,\tmask_map_ptr,\toutput_ptr,\toutput_probs_ptr,\tnum_experts: int,\tN: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_unpermute_with_mask_map", "modulename": "linghe.utils.scatter", "qualname": "triton_unpermute_with_mask_map", "kind": "function", "doc": "

\n", "signature": "(grad: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.utils.silu", "modulename": "linghe.utils.silu", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.silu.silu_and_block_quant_forward_kernel", "modulename": "linghe.utils.silu", "qualname": "silu_and_block_quant_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\tM,\tn: int,\tROUND: int,\tOUTPUT_MODE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_forward", "kind": "function", "doc": "

\n", "signature": "(x, out=None, scale=None, round_scale=False, output_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.silu_and_block_quant_backward_kernel", "modulename": "linghe.utils.silu", "qualname": "silu_and_block_quant_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tg_ptr,\tx_ptr,\tdx_ptr,\tdx_scale_ptr,\ttranspose_dx_ptr,\ttranspose_dx_scale_ptr,\tM,\tn: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_backward", "kind": "function", "doc": "

\n", "signature": "(g, x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.batch_weighted_silu_and_block_quant_forward_kernel", "modulename": "linghe.utils.silu", "qualname": "batch_weighted_silu_and_block_quant_forward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\tcount_ptr,\taccum_ptr,\tn: int,\tE: int,\tROUND: int,\tOUTPUT_MODE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_forward", "kind": "function", "doc": "

\n", "signature": "(\tx,\tweight,\tcounts,\tsplits=None,\tout=None,\tscale=None,\tround_scale=False,\toutput_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.batch_weighted_silu_and_block_quant_backward_kernel", "modulename": "linghe.utils.silu", "qualname": "batch_weighted_silu_and_block_quant_backward_kernel", "kind": "function", "doc": "

\n", "signature": "(\tg_ptr,\tx_ptr,\tweight_ptr,\tcount_ptr,\taccum_ptr,\tdx_ptr,\tdx_scale_ptr,\ttranspose_dx_ptr,\ttranspose_dx_scale_ptr,\tdw_ptr,\tn: int,\tE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_backward", "kind": "function", "doc": "

\n", "signature": "(g, x, weight, counts, splits=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose", "modulename": "linghe.utils.transpose", "kind": "module", "doc": "

Copyright (c) Ant Financial Service Group and its affiliates.

\n"}, {"fullname": "linghe.utils.transpose.deprecated_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "deprecated_transpose_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, t_ptr, M, N, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_depracated_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_depracated_transpose", "kind": "function", "doc": "

\n", "signature": "(x):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, t_ptr, M, N, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_dim_0_1_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_dim_0_1_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, t_ptr, B, M, b_stride, m_stride, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose", "kind": "function", "doc": "

\n", "signature": "(x, dim0=None, dim1=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_and_pad_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_and_pad_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, t_ptr, M, N, P, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose_and_pad", "kind": "function", "doc": "

\n", "signature": "(x, out=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.batch_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "batch_transpose_kernel", "kind": "function", "doc": "

\n", "signature": "(xs_ptr, xts_ptr, M, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose", "kind": "function", "doc": "

\n", "signature": "(xs, xts=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.batch_transpose_and_pad_kernel", "modulename": "linghe.utils.transpose", "qualname": "batch_transpose_and_pad_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, t_ptr, count_ptr, accum_ptr, pad_accum_ptr, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose_and_pad", "kind": "function", "doc": "

\n", "signature": "(x, count_list, x_t=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.configs", "modulename": "linghe.utils.transpose", "qualname": "configs", "kind": "variable", "doc": "

\n", "default_value": "[<triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, 
<triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>]"}, {"fullname": "linghe.utils.transpose.opt_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "opt_transpose_kernel", "kind": "function", "doc": "

\n", "signature": "(x_ptr, t_ptr, M, N, D, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_opt_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_opt_transpose", "kind": "function", "doc": "

\n", "signature": "(x):", "funcdef": "def"}]; + + // mirrored in build-search-index.js (part 1) + // Also split on html tags. this is a cheap heuristic, but good enough. + elasticlunr.tokenizer.setSeperator(/[\s\-.;&_'"=,()]+|<[^>]*>/); + + let searchIndex; + if (docs._isPrebuiltIndex) { + console.info("using precompiled search index"); + searchIndex = elasticlunr.Index.load(docs); + } else { + console.time("building search index"); + // mirrored in build-search-index.js (part 2) + searchIndex = elasticlunr(function () { + this.pipeline.remove(elasticlunr.stemmer); + this.pipeline.remove(elasticlunr.stopWordFilter); + this.addField("qualname"); + this.addField("fullname"); + this.addField("annotation"); + this.addField("default_value"); + this.addField("signature"); + this.addField("bases"); + this.addField("doc"); + this.setRef("fullname"); + }); + for (let doc of docs) { + searchIndex.addDoc(doc); + } + console.timeEnd("building search index"); + } + + return (term) => searchIndex.search(term, { + fields: { + qualname: {boost: 4}, + fullname: {boost: 2}, + annotation: {boost: 2}, + default_value: {boost: 2}, + signature: {boost: 2}, + bases: {boost: 2}, + doc: {boost: 1}, + }, + expand: true + }); +})(); \ No newline at end of file diff --git a/linghe/__init__.py b/linghe/__init__.py index e69de29..8b13789 100644 --- a/linghe/__init__.py +++ b/linghe/__init__.py @@ -0,0 +1 @@ + diff --git a/linghe/facade/add.py b/linghe/facade/add.py index 945ad0e..c40aca4 100644 --- a/linghe/facade/add.py +++ b/linghe/facade/add.py @@ -9,10 +9,25 @@ class InplaceAddFunction(torch.autograd.Function): + """ + + """ @staticmethod - def forward(ctx, x, y): + def forward(ctx, x: torch.Tensor, y: torch.Tensor): return triton_inplace_add(x, y) @staticmethod def backward(ctx, grad_output): return grad_output, grad_output + + +def inplace_add(x: torch.Tensor, y: torch.Tensor): + """ + inplace add y to x with mix precise + Args: + x: to be updated + y: add to x + Returns: + return updated x tensor + """ + return InplaceAddFunction.apply(x, y) \ No newline at end of file diff --git a/linghe/facade/fp32_linear.py b/linghe/facade/fp32_gemm.py similarity index 66% rename from linghe/facade/fp32_linear.py rename to linghe/facade/fp32_gemm.py index d356856..8ab9ae9 100644 --- a/linghe/facade/fp32_linear.py +++ b/linghe/facade/fp32_gemm.py @@ -10,9 +10,12 @@ triton_fp32_gemm_for_update) -class FusedFp32GEMM(torch.autograd.Function): +class Fp32GEMM(torch.autograd.Function): + """ + + """ @staticmethod - def forward(ctx, input, weight): + def forward(ctx, input: torch.Tensor, weight: torch.Tensor): shape = input.shape assert len(shape) == 3 input = input.view(shape[0] * shape[1], shape[2]) @@ -32,9 +35,22 @@ def backward(ctx, grad_output): grad_output = grad_output.view(shape[0] * shape[1], shape[2]) input, weight = ctx.saved_tensors - dx = triton_fp32_gemm_for_backward(grad_output, weight, accum=False) + dx = triton_fp32_gemm_for_backward(grad_output, weight) dx = dx.view(*ctx.shape) dw = triton_fp32_gemm_for_update(grad_output, input) return dx, dw + + +def fp32_gemm(input: torch.Tensor, weight: torch.Tensor): + """ + gemm with bf16/fp16 inputs and float32 output, + currently used in MoE router gemm. 
+    Args:
+        input: bf16/fp16 activation tensor
+        weight: bf16/fp16 weight tensor
+    Returns:
+        output of the GEMM
+    """
+    return Fp32GEMM.apply(input, weight)
\ No newline at end of file
diff --git a/linghe/facade/hadamard_quant_linear.py b/linghe/facade/hadamard_quant_linear.py
new file mode 100644
index 0000000..586f104
--- /dev/null
+++ b/linghe/facade/hadamard_quant_linear.py
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) Ant Financial Service Group and its affiliates.
+"""
+
+import math
+from typing import Optional
+
+import torch
+
+from linghe.quant.hadamard import triton_hadamard_quant
+
+
+
+class _HadamardQuantLinear(torch.autograd.Function):
+    @staticmethod
+    def forward(
+            ctx,
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            bias: Optional[torch.Tensor],
+            hadamard_matrix: torch.Tensor
+    ):
+        ctx.input_requires_grad = input.requires_grad
+        ctx.weight_requires_grad = weight.requires_grad
+        ctx.bias_requires_grad = bias is not None and bias.requires_grad
+
+        ctx.out_dtype = input.dtype
+        ctx.input_shape = input.shape
+        input = input.view(-1, input.shape[-1])
+
+        x_q, x_scale, xt_q, xt_scale = triton_hadamard_quant(input, hadamard_matrix)
+        w_q, w_scale, wt_q, wt_scale = triton_hadamard_quant(weight, hadamard_matrix)
+
+        output = torch._scaled_mm(x_q,
+                                  w_q.t(),
+                                  scale_a=x_scale,
+                                  scale_b=w_scale,
+                                  out_dtype=ctx.out_dtype,
+                                  use_fast_accum=True
+                                  )
+
+        if bias is not None:
+            output += bias
+
+        saved_tensors = [
+            xt_q if ctx.weight_requires_grad else None,
+            xt_scale if ctx.weight_requires_grad else None,
+            wt_q if ctx.input_requires_grad else None,
+            wt_scale if ctx.input_requires_grad else None,
+            hadamard_matrix if ctx.input_requires_grad or ctx.weight_requires_grad else None
+        ]
+
+        ctx.save_for_backward(*saved_tensors)
+        out_shape = (*ctx.input_shape[0:-1], -1)
+        return output.view(out_shape)
+
+    @staticmethod
+    def backward(
+            ctx,
+            output_grad: torch.Tensor,
+    ):
+        xt_q, xt_scale, wt_q, wt_scale, hadamard_matrix = ctx.saved_tensors
+        results = [None, None, None, None]
+
+        output_grad = output_grad.view(-1, output_grad.shape[-1])
+
+        y_q, y_scale, yt_q, yt_scale = triton_hadamard_quant(output_grad, hadamard_matrix)
+
+        dx = torch._scaled_mm(y_q,
+                              wt_q.t(),
+                              scale_a=y_scale,
+                              scale_b=wt_scale,
+                              out_dtype=ctx.out_dtype,
+                              use_fast_accum=True
+                              )
+
+        # calculate input grad and assign to results[0]
+        results[0] = dx.view(ctx.input_shape)
+
+        # calculate weight grad and assign to results[1]
+        dw = torch._scaled_mm(yt_q,
+                              xt_q.t(),
+                              scale_a=yt_scale,
+                              scale_b=xt_scale,
+                              out_dtype=ctx.out_dtype,
+                              use_fast_accum=True
+                              )
+        results[1] = dw
+
+        if ctx.bias_requires_grad:
+            # calculate bias grad and assign to results[2]
+            results[2] = torch.sum(output_grad, dim=0)
+
+        return tuple(results)
+
+class HadamardQuantLinear(torch.nn.Module):
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None
+    ):
+        """
+        A naive implementation of Hadamard transformation and quantization.
+        Args:
+            in_features: in feature number
+            out_features: out feature number
+            bias: whether to use a bias
+            device: weight device
+            dtype: weight dtype
+        """
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.parameter.Parameter(
+            torch.empty((out_features, in_features), device=device,
+                        dtype=dtype))
+        if bias:
+            self.bias = torch.nn.parameter.Parameter(
+                torch.empty(out_features, device=device, dtype=dtype))
+        else:
+            self.bias = None
+
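+        # pick the Hadamard matrix size per device: 32 on H20, 64 otherwise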
+        size = 32 if 'H20' in torch.cuda.get_device_properties(0).name else 64
+        data = self._hadamard_matrix(size, device=device, dtype=dtype,
+                                     norm=True)
+        self.hadamard_matrix = torch.nn.parameter.Parameter(data,
+                                                            requires_grad=False)
+        self.reset_parameters()
+
+    def _hadamard_matrix(self, size, device=None, dtype=None, norm=False):
+        assert 2 ** int(math.log2(size)) == size
+        m2 = torch.tensor([[1, 1], [1, -1]], device=device, dtype=torch.float32)
+        m = m2
+        for _ in range(int(math.log2(size)) - 1):
+            m = torch.kron(m, m2)
+        if norm:
+            m = m / size ** 0.5
+        if dtype is not None:
+            m = m.to(dtype)
+        return m
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.training:
+            return _HadamardQuantLinear.apply(input, self.weight, self.bias,
+                                              self.hadamard_matrix)
+        else:
+            output = input @ self.weight.t()
+            if self.bias is not None:
+                output = output + self.bias
+            return output
+
+    def extra_repr(self) -> str:
+        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
+
+    def reset_parameters(self):
+        self.weight.data.normal_(mean=0.0, std=0.02)
+        if self.bias is not None:
+            self.bias.data.zero_()
diff --git a/linghe/facade/loss.py b/linghe/facade/loss.py
index fff59dd..1fa7294 100644
--- a/linghe/facade/loss.py
+++ b/linghe/facade/loss.py
@@ -10,6 +10,9 @@
 
 
 class SoftmaxCrossEntropyFunction(torch.autograd.Function):
+    """
+
+    """
     @staticmethod
     def forward(ctx, logits, labels, inplace=False):
         shape = logits.shape
@@ -38,7 +41,26 @@ def backward(ctx, grad_output):
         return grad, None, None, None
 
 
+def softmax_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, inplace: bool = False):
+    """
+    softmax cross entropy
+    Args:
+        logits: logits tensor, shape [..., dim]
+        labels: labels tensor, shape [...]
+        inplace: update the gradient in the `logits` tensor if True
+
+    Returns:
+        per-token loss
+    """
+    assert logits.is_contiguous()
+    assert labels.is_contiguous()
+    return SoftmaxCrossEntropyFunction.apply(logits, labels, inplace)
+
+
 class GradScalingFunction(torch.autograd.Function):
+    """
+
+    """
     @staticmethod
     def forward(ctx, x, coef=0.2):
         ctx.coef = coef
diff --git a/linghe/facade/norm.py b/linghe/facade/norm.py
index c5d31ad..435942f 100644
--- a/linghe/facade/norm.py
+++ b/linghe/facade/norm.py
@@ -10,6 +10,9 @@
 
 
 class RMSNormFunction(torch.autograd.Function):
+    """
+
+    """
     @staticmethod
     def forward(ctx, x, weight, eps=1e-6):
         output = triton_rms_norm_forward(
@@ -37,17 +40,35 @@ def backward(ctx, dy):
 
         return dx, dw, None
 
 
+def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
+    """
+    RMS norm of x with weight
+    Args:
+        x: activation tensor
+        weight: weight tensor
+        eps: epsilon for RMS
+
+    Returns:
+        RMS norm output
+    """
+    assert x.is_contiguous()
+    assert weight.is_contiguous()
+    return RMSNormFunction.apply(x, weight, eps)
+
 
 class GroupNormGateFunction(torch.autograd.Function):
+    """
+
+    """
     @staticmethod
-    def forward(ctx, x, gate, weight, eps=1e-6, group_size=4):
+    def forward(ctx, attn_output, gate, weight, eps=1e-6, group_size=4):
         output = triton_group_norm_gate_forward(
-            x,
+            attn_output,
             gate,
             weight.data,
             eps=eps,
             group_size=group_size
         )
-        ctx.save_for_backward(x, gate, weight.data)
+        ctx.save_for_backward(attn_output, gate, weight.data)
         ctx.eps = eps
         ctx.group_size = group_size
@@ -55,11 +76,11 @@ def forward(ctx, x, gate, weight, eps=1e-6, group_size=4):
 
     @staticmethod
     def backward(ctx, dy):
-        x, gate, weight = ctx.saved_tensors
+        attn_output, gate, weight = ctx.saved_tensors
 
         dx, dg, dw = triton_group_norm_gate_backward(
             dy,
-            x,
+            attn_output,
             gate,
             weight,
             ctx.eps,
@@ -67,3 +88,23 @@
         )
 
         return dx, dg, dw, None, None
+
+
+
+def group_norm_gate(attn_output: torch.Tensor,
+                    gate: torch.Tensor,
+                    weight: torch.Tensor,
+                    eps: float = 1e-6,
+                    group_size: int = 4):
+    """
+    return group_rms_norm(transpose(attn_output, [0, 1]), weight) * sigmoid(gate)
+    Args:
+        attn_output: output of core attention, shape [bs, length, n_heads, head_dim]
+        gate: gate tensor for the attention output, shape [length, bs, dim]
+        weight: weight of the RMS norm, shape [dim]
+        eps: epsilon for RMS
+        group_size: group size of the group RMS norm
+    Returns:
+        output with shape [length, bs, dim]
+    """
+    return GroupNormGateFunction.apply(attn_output, gate, weight, eps, group_size)
\ No newline at end of file
diff --git a/linghe/facade/rope.py b/linghe/facade/rope.py
index f12d518..3719095 100644
--- a/linghe/facade/rope.py
+++ b/linghe/facade/rope.py
@@ -10,6 +10,9 @@
 
 
 class QkNormHalfRopeFunction(torch.autograd.Function):
+    """
+
+    """
     @staticmethod
     def forward(ctx, qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4,
                 eps=1e-6):
@@ -47,3 +50,35 @@ def backward(ctx, grad_q, grad_k, grad_v):
                                                       transpose=True,
                                                       interleave=True)
         return dqkv, dqw, dkw, None, None, None, None
+
+
+def qk_norm_half_rope(qkv: torch.Tensor,
+                      q_norm_weight: torch.Tensor,
+                      k_norm_weight: torch.Tensor,
+                      freqs: torch.Tensor,
+                      H: int = 32,
+                      h: int = 4,
+                      eps: float = 1e-6):
+    """
+    Split qkv into q/k/v, apply QK norm and half RoPE to q/k, and transpose
+    q/k/v to the flash-attention layout.
+    Args:
+        qkv: QKV tensor with size [S, B, dim], heads are interleaved
+        q_norm_weight: RMS norm weight for query
+        k_norm_weight: RMS norm weight for key
+        freqs: freqs tensor based on half dim
+        H: number of attention heads
+        h: number of key/value heads
+        eps: epsilon value for L2 normalization
+
+    Returns:
+        qo: shape [B, S, H, head_dim]
+        ko: shape [B, S, h, head_dim]
+        vo: shape [B, S, h, head_dim]
+    """
+    return QkNormHalfRopeFunction.apply(qkv,
+                                        q_norm_weight,
+                                        k_norm_weight,
+                                        freqs,
+                                        H,
+                                        h,
+                                        eps)
\ No newline at end of file
diff --git a/linghe/facade/smooth_quant_linear.py b/linghe/facade/smooth_quant_linear.py
new file mode 100644
index 0000000..fccbabb
--- /dev/null
+++ b/linghe/facade/smooth_quant_linear.py
@@ -0,0 +1,154 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) Ant Financial Service Group and its affiliates.
+"""
+
+from typing import Optional
+
+import torch
+
+
+from linghe.quant.smooth import triton_smooth_quant, \
+    triton_transpose_smooth_quant
+from linghe.utils.transpose import triton_transpose_and_pad
+from linghe.utils.reduce import triton_abs_max
+
+class _SmoothQuantLinear(torch.autograd.Function):
+    @staticmethod
+    def forward(
+            ctx,
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            smooth_scale: torch.Tensor,
+            bias: Optional[torch.Tensor]
+    ):
+        ctx.input_requires_grad = input.requires_grad
+        ctx.weight_requires_grad = weight.requires_grad
+        ctx.bias_requires_grad = bias is not None and bias.requires_grad
+
+        ctx.out_dtype = input.dtype
+        ctx.input_shape = input.shape
+        input = input.view(-1, input.shape[-1])
+
+        x_q, x_scale, x_maxs = triton_smooth_quant(input, 1 / smooth_scale)
+        w_q, w_scale, w_maxs = triton_smooth_quant(weight, smooth_scale)
+
+        output = torch._scaled_mm(x_q,
+                                  w_q.t(),
+                                  scale_a=x_scale.view(-1, 1),
+                                  scale_b=w_scale.view(1, -1),
+                                  out_dtype=ctx.out_dtype,
+                                  use_fast_accum=True)
+
+        if bias is not None:
+            output += bias
+
+        saved_tensors = [
+            x_q if ctx.weight_requires_grad else None,
+            x_scale if ctx.weight_requires_grad else None,
+            w_q if ctx.input_requires_grad else None,
+            w_scale if ctx.input_requires_grad else None,
+            smooth_scale if ctx.input_requires_grad or ctx.weight_requires_grad else None
+        ]
+
+        ctx.save_for_backward(*saved_tensors)
+        out_shape = (*ctx.input_shape[0:-1], -1)
+        return output.view(out_shape)
+
+    @staticmethod
+    def backward(
+            ctx,
+            output_grad: torch.Tensor
+    ):
+        x_q, x_s, w_q, w_s, smooth_scale = ctx.saved_tensors
+        results = [None, None, None, None]
+
+        output_grad = output_grad.view(-1, output_grad.shape[-1])
+
+        y_q, y_scale, y_maxs = triton_smooth_quant(output_grad, w_s)
+
+        wt_q = triton_transpose_and_pad(w_q, pad=True)
+        dx = torch._scaled_mm(y_q,
+                              wt_q.t(),
+                              scale_a=y_scale.view(-1, 1),
+                              scale_b=smooth_scale.view(1, -1),
+                              out_dtype=ctx.out_dtype,
+                              use_fast_accum=True)
+
+        # calculate input grad and assign to results[0]
+        results[0] = dx.view(ctx.input_shape)
+
+        # calculate weight grad and assign to results[1]
+        yt_q, yt_scale, yt_maxs = triton_transpose_smooth_quant(output_grad, x_s)
+
+        xt_q = triton_transpose_and_pad(x_q, pad=True)
+        dw = torch._scaled_mm(yt_q,
+                              xt_q.t(),
+                              scale_a=yt_scale.view(-1, 1),
+                              scale_b=1 / smooth_scale.view(1, -1),
+                              out_dtype=ctx.out_dtype,
+                              use_fast_accum=True)
+
+        results[1] = dw
+
+        if ctx.bias_requires_grad:
+            # calculate bias grad and assign to results[3], matching the
+            # (input, weight, smooth_scale, bias) argument order of forward
+            results[3] = torch.sum(output_grad, dim=0)
+
+        return tuple(results)
+
+
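+# QuantLinear below refreshes its smoothing scale every `gap_step` training
+# steps as sqrt(max|input| * max|weight|), a SmoothQuant-style balance
+# between activation and weight ranges (see forward).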
+class QuantLinear(torch.nn.Module):
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.parameter.Parameter(
+            torch.empty((out_features, in_features), device=device,
+                        dtype=dtype))
+        if bias:
+            self.bias = torch.nn.parameter.Parameter(
+                torch.empty(out_features, device=device, dtype=dtype))
+        else:
+            self.bias = None
+
+        self.gap_step = 16
+        self.decay_coef = 0.9
+        self.smooth_scale = None
+        self.smooth_update_step = 0
+
+        self.reset_parameters()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.training:
+
+            if self.smooth_update_step % self.gap_step == 0:
+                input_maxs = triton_abs_max(input)
+                weight_maxs = triton_abs_max(self.weight.data)
+                self.smooth_scale = torch.sqrt(input_maxs * weight_maxs)
+
+            # arguments follow the (input, weight, smooth_scale, bias) order
+            # of _SmoothQuantLinear.forward, which returns a single tensor
+            output = _SmoothQuantLinear.apply(input,
+                                              self.weight,
+                                              self.smooth_scale,
+                                              self.bias)
+            self.smooth_update_step += 1
+        else:
+            output = input @ self.weight.t()
+            if self.bias is not None:
+                output = output + self.bias
+        return output
+
+    def extra_repr(self) -> str:
+        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
+
+    def reset_parameters(self):
+        self.weight.data.normal_(mean=0.0, std=0.02)
+        if self.bias is not None:
+            self.bias.data.zero_()
diff --git a/linghe/facade/transpose.py b/linghe/facade/transpose.py
index d3ecfaa..9d8de83 100644
--- a/linghe/facade/transpose.py
+++ b/linghe/facade/transpose.py
@@ -9,6 +9,9 @@
 
 
 class TransposeDim01Function(torch.autograd.Function):
+    """
+
+    """
     @staticmethod
     def forward(ctx, x):
         return triton_transpose(x, dim0=0, dim1=1)
@@ -16,3 +19,15 @@ def forward(ctx, x):
+
     @staticmethod
     def backward(ctx, grad_output):
         return triton_transpose(grad_output, dim0=0, dim1=1)
+
+
+def transpose_dim01(x):
+    """
+    transpose the first two dims of a tensor; x.ndim must not be greater than 4
+    Args:
+        x: input tensor
+
+    Returns:
+        the transposed tensor
+    """
+    return TransposeDim01Function.apply(x)
\ No newline at end of file
diff --git a/linghe/gemm/blockwise_fp8_gemm.py b/linghe/gemm/blockwise_fp8_gemm.py
new file mode 100644
index 0000000..da9416a
--- /dev/null
+++ b/linghe/gemm/blockwise_fp8_gemm.py
@@ -0,0 +1,242 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) Ant Financial Service Group and its affiliates.
+"""
+
+import torch
+import triton
+import triton.language as tl
+from triton import Config
+
+fp8_gemm_configs = [
+    Config({"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n},
+           num_stages=num_stages, num_warps=8)
+    for block_m in [32, 64, 128]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
+]
+
+
+# @triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
+@triton.jit
+def fp8_gemm_bb_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        a_s_ptr,
+        b_s_ptr,
+        M,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+):
+    # a blockwise quantization, b blockwise quantization.
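+    # Both operands are quantized blockwise: a_s holds one scale per
+    # (M-block, K-block) and b_s one scale per (N-block, K-block), so each
+    # K-block's partial dot product is dequantized by a_s * b_s before it is
+    # added to the fp32 accumulator.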
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_SIZE_K)
+    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    # b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+    b_ptrs = b_ptr + offs_n[:, None] * K + offs_k[None, :]
+    nb = K // BLOCK_SIZE_K
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for i in range(0, k):
+        a_s = tl.load(a_s_ptr + pid_m * nb + i)
+        b_s = tl.load(b_s_ptr + pid_n * nb + i)
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K,
+                    other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K,
+                    other=0.0)
+        accumulator += tl.dot(a, tl.trans(b)) * (a_s * b_s)
+        # accumulator = tl.dot(a, tl.trans(b), accumulator)
+        # accumulator += (accumulators-accumulator) * scale
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+    c = accumulator.to(c_ptr.dtype.element_ty)
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, c, mask=mask)
+
+
+
+def triton_bb_fp8_gemm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       a_s: torch.Tensor,
+                       b_s: torch.Tensor,
+                       out_dtype=torch.bfloat16,
+                       block_size=128):
+    assert a.is_contiguous() and b.is_contiguous()
+    assert a_s.is_contiguous() and b_s.is_contiguous()
+    K = a.size(-1)
+    M = a.numel() // K
+    N = b.size(0)
+    c = torch.empty(M, N, dtype=out_dtype, device=a.device)
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
+                         triton.cdiv(N, META["BLOCK_SIZE_N"]))  # noqa
+
+    fp8_gemm_bb_kernel[grid](a, b, c, a_s, b_s,
+                             M, N, K,
+                             BLOCK_SIZE_K=block_size,
+                             BLOCK_SIZE_M=block_size,
+                             BLOCK_SIZE_N=block_size,
+                             num_warps=8,
+                             num_stages=4
+                             )
+    return c
+
+
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
+@triton.jit
+def fp8_gemm_tb_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        a_s_ptr,
+        b_s_ptr,
+        M,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+):
+    # a tilewise quantization, b blockwise quantization.
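+    # a is quantized tilewise (one scale per row per K-block, stride k along
+    # the row) and b blockwise (one scale per BLOCK_SIZE_K rows per K-block);
+    # below, the (accumulators - accumulator) difference isolates this
+    # K-block's contribution so it can be rescaled before accumulation.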
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_SIZE_K)
+    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+    a_s_ptrs = a_s_ptr + offs_m * k
+    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for i in range(k):
+        # a = tl.load(a_ptrs)
+        # b = tl.load(b_ptrs)
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K,
+                    other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K,
+                    other=0.0)
+        a_s = tl.load(a_s_ptrs)
+        b_s = tl.load(b_s_ptrs)
+        # accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        accumulators = tl.dot(a, b, accumulator)
+        accumulator += (accumulators - accumulator) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+        a_s_ptrs += 1
+        b_s_ptrs += 1
+    c = accumulator.to(c_ptr.dtype.element_ty)
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
+    # tl.store(c_ptrs, c)
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, c, mask=mask)
+
+
+
+
+def triton_tb_fp8_gemm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       a_s: torch.Tensor,
+                       b_s: torch.Tensor,
+                       out_dtype=torch.bfloat16,
+                       block_size=128):
+    assert a.is_contiguous() and b.is_contiguous()
+    assert a_s.is_contiguous() and b_s.is_contiguous()
+    K = a.size(-1)
+    M = a.numel() // K
+    N = b.size(0)
+    c = torch.empty(M, N, dtype=out_dtype, device=a.device)
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
+                         triton.cdiv(N, META["BLOCK_SIZE_N"]))  # noqa
+
+    fp8_gemm_tb_kernel[grid](a, b, c,
+                             a_s, b_s,
+                             M, N, K,
+                             block_size
+                             )
+    return c
+
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
+@triton.jit
+def fp8_gemm_tt_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        a_s_ptr,
+        b_s_ptr,
+        M,
+        N: tl.constexpr,
+        K: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+):
+    # a and b all tilewise quantization.
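+    # a and b are both quantized tilewise (one scale per row per K-block);
+    # each K-block's partial dot product is dequantized by the outer product
+    # a_s[:, None] * b_s[None, :] before accumulation.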
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_SIZE_K)
+    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+    a_s_ptrs = a_s_ptr + offs_m * k
+    b_s_ptrs = b_s_ptr + offs_n * k
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for i in range(k):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K,
+                    other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K,
+                    other=0.0)
+        a_s = tl.load(a_s_ptrs)
+        b_s = tl.load(b_s_ptrs)
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+        a_s_ptrs += 1
+        b_s_ptrs += 1
+
+    c = accumulator.to(c_ptr.dtype.element_ty)
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, c, mask=mask)
+
+
+def triton_tt_fp8_gemm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       a_s: torch.Tensor,
+                       b_s: torch.Tensor,
+                       out_dtype=torch.bfloat16,
+                       block_size=128):
+    assert a.is_contiguous() and b.is_contiguous()
+    assert a_s.is_contiguous() and b_s.is_contiguous()
+    K = a.size(-1)
+    M = a.numel() // K
+    N = b.size(0)
+    c = torch.empty(*a.size()[:-1], N, dtype=out_dtype, device=a.device)
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
+                         triton.cdiv(N, META["BLOCK_SIZE_N"]))  # noqa
+    fp8_gemm_tt_kernel[grid](a, b, c,
+                             a_s, b_s,
+                             M, N, K,
+                             block_size)
+    return c
diff --git a/linghe/gemm/channelwise_fp8_gemm.py b/linghe/gemm/channelwise_fp8_gemm.py
new file mode 100644
index 0000000..a4d978f
--- /dev/null
+++ b/linghe/gemm/channelwise_fp8_gemm.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) Ant Financial Service Group and its affiliates.
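+
+Channel-wise scaled FP8 GEMM: triton_scaled_mm below takes per-row scales for
+`a` and per-column scales for `b`, similar to torch._scaled_mm but with
+optional accumulation into an existing output tensor.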
+""" + +import torch +import triton +import triton.language as tl + + +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" + + +# fp8_gemm_configs = [ +# Config({"BLOCK_SIZE_K": block_k, "BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, num_stages=num_stages, num_warps=num_warps) +# for block_k in [64, 128, 256] +# for block_m in [64, 128, 256] +# for block_n in [64, 128, 256] +# for num_stages in [2, 3, 4, 5] +# for num_warps in [4, 8, 16] +# # for num_stages in [3] +# # for num_warps in [8] +# ] + +# @triton.autotune(configs=fp8_gemm_configs, key=["M", "N", "K"]) +@triton.jit +def scaled_mm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + N, + K, + ACCUM: tl.constexpr, + EVEN: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_scale = tl.load(a_scale_ptr + offs_m) + b_scale = tl.load(b_scale_ptr + offs_n) + + if ACCUM: + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + accumulator = tl.load(c_ptrs).to(tl.float32) + a_s = 1 / tl.maximum(a_scale, 1e-30) + b_s = 1 / tl.maximum(b_scale, 1e-30) + accumulator = accumulator * a_s[:, None] * b_s[None, :] + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if EVEN: + for i in range(k): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + else: + for i in range(k): + indices = i * BLOCK_SIZE_K + offs_k + a = tl.load(a_ptrs, mask=indices[None, :] < K) + b = tl.load(b_ptrs, mask=indices[:, None] < K) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + accumulator = accumulator * a_scale[:, None] * b_scale[None, :] + accumulator = accumulator.to(c_ptr.dtype.element_ty) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + tl.store(c_ptrs, accumulator) + + +def triton_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + out_dtype=torch.float32, + c=None, + accum=True): + """ + similar to torch._scaled_mm, support accumulating gemm output to c + and low precision output tensor + Args: + a: left fp8 tensor + b: right fp8 tensor, column-major + a_scale: fp32 scale of a + b_scale: fp32 scale of b + out_dtype: output tensor dtype + c: output tensor + accum: accumulate output on c if True + + Returns: + c: output tensor + """ + assert a.is_contiguous() and b.is_contiguous() + M, K = a.size() + N, K = b.size() + ACCUM = accum and c is not None + if c is None: + c = torch.empty(M, N, dtype=out_dtype, device=a.device) + BLOCK_SIZE_K = 128 + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 256 + EVEN = K % BLOCK_SIZE_K == 0 + grid = lambda META: ( + M // META["BLOCK_SIZE_M"], N // META["BLOCK_SIZE_N"]) # noqa + scaled_mm_kernel[grid](a, b, c, + a_scale, + b_scale, + N, K, + ACCUM, + EVEN, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_stages=3, + num_warps=8 + ) + + return c diff --git a/linghe/gemm/fp32_gemm.py b/linghe/gemm/fp32_gemm.py index e4bbeb9..8f44067 100644 --- a/linghe/gemm/fp32_gemm.py +++ b/linghe/gemm/fp32_gemm.py @@ -59,10 +59,19 @@ def fp32_gemm_kernel( tl.store(c_ptrs, c) -# a, bf16 -# b, bf16 
diff --git a/linghe/gemm/fp32_gemm.py b/linghe/gemm/fp32_gemm.py
index e4bbeb9..8f44067 100644
--- a/linghe/gemm/fp32_gemm.py
+++ b/linghe/gemm/fp32_gemm.py
@@ -59,10 +59,19 @@ def fp32_gemm_kernel(
     tl.store(c_ptrs, c)
 
 
-# a, bf16
-# b, bf16
-# c, fp32
+
 def triton_fp32_gemm(a: torch.Tensor, b: torch.Tensor):
+    """
+    return the fp32 GEMM result for fp16/bf16 inputs;
+    it is mainly used for the MoE router GEMM
+    and is not suitable for large GEMMs
+    Args:
+        a: left matrix with fp16/bf16 precision
+        b: right matrix with fp16/bf16 precision
+
+    Returns:
+        c: output with fp32 precision
+    """
     assert a.is_contiguous() and b.is_contiguous()
     M, K = a.size()
     N, K = b.size()
@@ -86,11 +95,11 @@ def triton_fp32_gemm(a: torch.Tensor, b: torch.Tensor):
     return c
 
 
+# @triton.autotune(configs=fp32_gemm_configs, key=["M", "N", "K"])
 @triton.jit
-def scaled_fp32_gemm_kernel(
+def fp32_gemm_for_backward_kernel(
     a_ptr,
     b_ptr,
-    scale_ptr,
     c_ptr,
     M,
     N: tl.constexpr,
@@ -106,61 +115,64 @@
     offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
     offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
-    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+    b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N
 
     c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
     for i in range(k):
-        a = tl.load(a_ptrs).to(tl.float32)
+        a = tl.load(a_ptrs)
         b = tl.load(b_ptrs).to(tl.float32)
         # c += tl.dot(a, b)
         c = tl.dot(a, b, c)
         a_ptrs += BLOCK_SIZE_K
-        b_ptrs += BLOCK_SIZE_K
-
-    scale = tl.load(
-        scale_ptr + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
-    c *= scale[:, None]
-
+        b_ptrs += BLOCK_SIZE_K * N
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
     tl.store(c_ptrs, c)
 
 
-def triton_scaled_fp32_gemm(a: torch.Tensor, b: torch.Tensor,
-                            scale: torch.Tensor):
+def triton_fp32_gemm_for_backward(a: torch.Tensor,
+                                  b: torch.Tensor):
+    """
+    mixed-precision GEMM for the backward pass, computes a @ b.float()
+    Args:
+        a: input gradient, fp32
+        b: gemm weight, bf16/fp16
+    Returns:
+        c: gradient of the activation
+    """
     assert a.is_contiguous() and b.is_contiguous()
     M, K = a.size()
-    N, K = b.size()
-    c = torch.empty(M, N, dtype=torch.float32, device=a.device)
+    K, N = b.size()
+    c = torch.empty((M, N), dtype=b.dtype, device=b.device)
     grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
                          triton.cdiv(N, META["BLOCK_SIZE_N"]))  # noqa
     BLOCK_SIZE_K = 128
     BLOCK_SIZE_M = 32
     BLOCK_SIZE_N = 128
     num_warps = 4
-    num_stages = 3
-    scaled_fp32_gemm_kernel[grid](a, b, scale, c,
-                                  M, N, K,
-                                  BLOCK_SIZE_K,
-                                  BLOCK_SIZE_M,
-                                  BLOCK_SIZE_N,
-                                  num_warps=num_warps,
-                                  num_stages=num_stages
-                                  )
+    num_stages = 2
+    fp32_gemm_for_backward_kernel[grid](a, b, c,
+                                        M, N, K,
+                                        BLOCK_SIZE_K,
+                                        BLOCK_SIZE_M,
+                                        BLOCK_SIZE_N,
+                                        num_warps=num_warps,
+                                        num_stages=num_stages
+                                        )
     return c
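# Example (illustrative, not part of this PR): the router forward/dgrad pair
# above, sketched with the benchmark shapes used elsewhere in this repo
# (8192 tokens, hidden 2048, 256 experts); tensors are random stand-ins.
import torch
from linghe.gemm.fp32_gemm import (triton_fp32_gemm,
                                   triton_fp32_gemm_for_backward)

hidden = torch.randn(8192, 2048, device="cuda", dtype=torch.bfloat16)
router_w = torch.randn(256, 2048, device="cuda", dtype=torch.bfloat16)
logits = torch.empty(0)
logits = triton_fp32_gemm(hidden, router_w)  # fp32 [8192, 256] = hidden @ w.t()
d_logits = torch.randn_like(logits)          # fp32 upstream gradient
# dgrad consumes the weight as [K, N] = [256, 2048], so no transpose is needed
d_hidden = triton_fp32_gemm_for_backward(d_logits, router_w)  # bf16 [8192, 2048]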
 
 
 # @triton.autotune(configs=fp32_gemm_configs, key=["M", "N", "K"])
 @triton.jit
-def fp32_gemm_for_backward_kernel(
+def fp32_gemm_for_update_kernel(
     a_ptr,
     b_ptr,
     c_ptr,
     M,
     N: tl.constexpr,
     K: tl.constexpr,
-    ACCUM: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -171,63 +183,62 @@
     offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
     offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    a_ptrs = a_ptr + offs_m[None, :] + offs_k[:, None] * M
     b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N
 
-    if ACCUM:
-        c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(
-            tl.float32)
-    else:
-        c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
+    c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    # c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(tl.float32)
     for i in range(k):
-        a = tl.load(a_ptrs)
+        a = tl.trans(tl.load(a_ptrs)).to(tl.float32)
         b = tl.load(b_ptrs).to(tl.float32)
         # c += tl.dot(a, b)
         c = tl.dot(a, b, c)
-        a_ptrs += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * M
         b_ptrs += BLOCK_SIZE_K * N
+
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
     tl.store(c_ptrs, c)
 
 
-# a: router output, fp32
-# b: router weight, bf16, should be transposed before calculation
-# c: dy of rms, bf16, shoule be accumlated
-def triton_fp32_gemm_for_backward(a: torch.Tensor, b: torch.Tensor,
-                                  c: Optional[torch.Tensor] = None,
-                                  accum=False):
+def triton_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor):
+    """
+    mixed-precision GEMM for updating the weight, computes a.t() @ b
+    Args:
+        a: gradient of the output, fp32, token-major, i.e. [K, M]
+        b: input activation, bf16/fp16, token-major, i.e. [K, N]
+    Returns:
+        c: gradient of the weight
+    """
     assert a.is_contiguous() and b.is_contiguous()
-    M, K = a.size()
+    K, M = a.size()
     K, N = b.size()
-    if c is None:
-        c = torch.empty((M, N), dtype=b.dtype, device=b.device)
-        accum = False
+    c = torch.empty((M, N), dtype=b.dtype, device=b.device)
     grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
                          triton.cdiv(N, META["BLOCK_SIZE_N"]))  # noqa
     BLOCK_SIZE_K = 128
     BLOCK_SIZE_M = 32
     BLOCK_SIZE_N = 128
     num_warps = 4
-    num_stages = 2
-    fp32_gemm_for_backward_kernel[grid](a, b, c,
-                                        M, N, K, accum,
-                                        BLOCK_SIZE_K,
-                                        BLOCK_SIZE_M,
-                                        BLOCK_SIZE_N,
-                                        num_warps=num_warps,
-                                        num_stages=num_stages
-                                        )
+    num_stages = 3
+    fp32_gemm_for_update_kernel[grid](a, b, c,
+                                      M, N, K,
+                                      BLOCK_SIZE_K,
+                                      BLOCK_SIZE_M,
+                                      BLOCK_SIZE_N,
+                                      num_warps=num_warps,
+                                      num_stages=num_stages
+                                      )
    return c
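# Example (illustrative, not part of this PR): the wgrad companion; both
# operands are passed token-major ([K, M] and [K, N]) per the docstring above,
# so the kernel's in-register transpose yields dw = d_logits.t() @ hidden.
import torch
from linghe.gemm.fp32_gemm import triton_fp32_gemm_for_update

d_logits = torch.randn(8192, 256, device="cuda", dtype=torch.float32)  # [K, M]
hidden = torch.randn(8192, 2048, device="cuda", dtype=torch.bfloat16)  # [K, N]
d_w = triton_fp32_gemm_for_update(d_logits, hidden)  # bf16 [256, 2048]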
 
 
-# @triton.autotune(configs=fp32_gemm_configs, key=["M", "N", "K"])
+
 @triton.jit
-def fp32_gemm_for_update_kernel(
+def scaled_fp32_gemm_kernel(
     a_ptr,
     b_ptr,
+    scale_ptr,
     c_ptr,
     M,
     N: tl.constexpr,
@@ -242,18 +253,21 @@
     offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
     offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + offs_m[None, :] + offs_k[:, None] * M
-    b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N
+    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
 
     c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    # c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(tl.float32)
     for i in range(k):
-        a = tl.trans(tl.load(a_ptrs)).to(tl.float32)
+        a = tl.load(a_ptrs).to(tl.float32)
         b = tl.load(b_ptrs).to(tl.float32)
         # c += tl.dot(a, b)
         c = tl.dot(a, b, c)
-        a_ptrs += BLOCK_SIZE_K * M
-        b_ptrs += BLOCK_SIZE_K * N
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    scale = tl.load(
+        scale_ptr + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
+    c *= scale[:, None]
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -261,13 +275,34 @@
     tl.store(c_ptrs, c)
 
 
-# a: router output, fp32, should be transposed before calculation
-# b: input of rms, bf16, should be transposed before calculation
-def triton_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor):
+def triton_scaled_fp32_gemm(a: torch.Tensor,
+                            b: torch.Tensor,
+                            scale: torch.Tensor):
+    """
+    c = (a * scale[:, None]) @ b
+    this kernel is used to fuse RMSNorm and quantization in the MoE layer
+    native implementation:
+        y = rms_norm(x),
+        y_q = quantization(y),
+        router_logits = y @ w
+    we cannot fuse rms_norm and quantization alone,
+    as the MoE router GEMM would still need y in bf16
+    fused implementation:
+        y_q, rms = quantization(rms_norm(x))
+        router_logits = (x / rms) @ w
+    so we need a scaled fp32 gemm kernel
+    Args:
+        a: activation tensor
+        b: weight tensor
+        scale: scale for the activation tensor, 1/rms
+
+    Returns:
+        c: router logits, fp32
+    """
     assert a.is_contiguous() and b.is_contiguous()
-    K, M = a.size()
-    K, N = b.size()
-    c = torch.empty((M, N), dtype=b.dtype, device=b.device)
+    M, K = a.size()
+    N, K = b.size()
+    c = torch.empty(M, N, dtype=torch.float32, device=a.device)
     grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
                          triton.cdiv(N, META["BLOCK_SIZE_N"]))  # noqa
     BLOCK_SIZE_K = 128
@@ -275,17 +310,20 @@
     BLOCK_SIZE_N = 128
     num_warps = 4
     num_stages = 3
-    fp32_gemm_for_update_kernel[grid](a, b, c,
-                                      M, N, K,
-                                      BLOCK_SIZE_K,
-                                      BLOCK_SIZE_M,
-                                      BLOCK_SIZE_N,
-                                      num_warps=num_warps,
-                                      num_stages=num_stages
-                                      )
+    scaled_fp32_gemm_kernel[grid](a, b,
+                                  scale,
+                                  c,
+                                  M, N, K,
+                                  BLOCK_SIZE_K,
+                                  BLOCK_SIZE_M,
+                                  BLOCK_SIZE_N,
+                                  num_warps=num_warps,
+                                  num_stages=num_stages
+                                  )
     return c
 
 
+
 @triton.jit
 def scaled_fp32_gemm_for_update_kernel(
     a_ptr,
@@ -309,7 +347,6 @@
     b_ptrs = b_ptr + offs_n[None, :] + offs_k[:, None] * N
 
     c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    # c = tl.load(c_ptr + offs_m[:, None] * N + offs_n[None, :]).to(tl.float32)
     for i in range(k):
         scale = tl.load(
             scale_ptr + i * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K))
@@ -326,11 +363,19 @@
     tl.store(c_ptrs, c)
 
 
-# a: router output, fp32, should be transposed before calculation
-# b: input of rms, bf16, should be transposed before calculation
-# scale: 1/rms
-def triton_scaled_fp32_gemm_for_update(a: torch.Tensor, b: torch.Tensor,
+def triton_scaled_fp32_gemm_for_update(a: torch.Tensor,
+                                       b: torch.Tensor,
                                        scale: torch.Tensor):
+    """
+    see triton_scaled_fp32_gemm; computes the weight gradient
+    Args:
+        a: y, fp32, transposed before the call
+        b: activation before RMS norm, transposed before the call
+        scale: 1/rms
+
+    Returns:
+        dw: gradient of the router weight
+    """
     assert a.is_contiguous() and b.is_contiguous()
     K, M = a.size()
     K, N = b.size()
diff --git a/linghe/quant/block/block.py b/linghe/quant/block.py
similarity index 85%
rename from linghe/quant/block/block.py
rename to linghe/quant/block.py
index cdb9d16..cfb37fa 100644
--- a/linghe/quant/block/block.py
+++ b/linghe/quant/block.py
@@ -28,9 +28,20 @@ def block_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr,
     tl.store(s_ptr + pid_m * n + pid_n, s)
 
 
-def block_quant(x,
+def triton_block_quant(x,
                 block_size=128,
                 round_scale=False):
+    """
+    blockwise quantize x
+    Args:
+        x: input tensor
+        block_size: quantization block size
+        round_scale: whether to round the scale to a power of 2
+
+    Returns:
+        y: quantized tensor, float8_e4m3fn
+        s: quantization scale, float32
+    """
     M, N = x.size()
     y = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
     s = x.new_empty(x.size(-2) // block_size, x.size(-1) // block_size,
diff --git a/linghe/quant/block/__init__.py b/linghe/quant/block/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/linghe/quant/block/group.py b/linghe/quant/block/group.py
deleted file mode 100644
index de5908c..0000000
--- a/linghe/quant/block/group.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Copyright (c) Ant 
Financial Service Group and its affiliates. -""" - -import torch -import triton -import triton.language as tl - - -@triton.jit -def group_quant_kernel(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: tl.constexpr, - K: tl.constexpr, ROUND: tl.constexpr): - pid = tl.program_id(axis=0) - offs = pid * N + tl.arange(0, K * BLOCK_SIZE) - n = tl.cdiv(N, K * BLOCK_SIZE) - soffs = pid * n * K + tl.arange(0, K) - for i in range(n): - x = tl.load(x_ptr + offs).to(tl.float32) - x = tl.reshape(x, (K, BLOCK_SIZE), can_reorder=False) - s = tl.maximum(tl.max(tl.abs(x), 1) / 448.0, 1e-30) - if ROUND: - s = tl.exp2(tl.ceil(tl.log2(s))) - y = x / s[:, None] - y = y.to(y_ptr.dtype.element_ty) - y = tl.reshape(y, (K * BLOCK_SIZE,), can_reorder=False) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + soffs, s) - offs += K * BLOCK_SIZE - soffs += K - - -def triton_group_quant(x, dtype=torch.float8_e4m3fn, group_size=128, - round_scale=False): - M, N = x.shape - K = 16 - assert N % group_size == 0 and N % (group_size * K) == 0 - assert x.is_contiguous() - - y = torch.empty((M, N), device=x.device, dtype=dtype) - s = torch.empty(M, N // group_size, device=x.device, dtype=torch.float32) - grid = (M,) # noqa - group_quant_kernel[grid](x, - y, - s, - N, - group_size, - K, - round_scale, - num_stages=5, - num_warps=4) - return y, s - - -@triton.jit -def persist_group_quant_kernel(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: tl.constexpr, - B: tl.constexpr, K: tl.constexpr, - ROUND: tl.constexpr): - pid = tl.program_id(axis=0) - offs = pid * B * N + tl.arange(0, B)[:, None] * N + tl.arange(0, - K * BLOCK_SIZE)[ - None, :] - n = tl.cdiv(N, K * BLOCK_SIZE) - soffs = pid * B * n * K + tl.arange(0, B)[:, None] * n * K + tl.arange(0, - K)[ - None, :] - - for j in range(n): - x = tl.load(x_ptr + offs).to(tl.float32) - x = tl.reshape(x, (B, K, BLOCK_SIZE)) - - s = tl.maximum(tl.max(tl.abs(x), 2) / 448.0, 1e-30) - if ROUND: - s = tl.exp2(tl.ceil(tl.log2(s))) - y = x / s[:, :, None] - y = y.to(y_ptr.dtype.element_ty) - y = tl.reshape(y, (B, K * BLOCK_SIZE)) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + soffs, s) - offs += K * BLOCK_SIZE - soffs += K - - -def triton_persist_group_quant(x, dtype=torch.float8_e4m3fn, group_size=128, - round_scale=False): - M, N = x.shape - device = x.device - K = 8 - B = 8 - assert N % group_size == 0 and N % (group_size * K) == 0 - assert x.is_contiguous() - - y = torch.empty((M, N), dtype=dtype, device=device) - s = torch.empty(M, N // group_size, device=x.device, dtype=torch.float32) - - grid = (M // B,) # noqa - persist_group_quant_kernel[grid](x, - y, - s, - N, - group_size, - B, - K, - round_scale, - num_stages=3, - num_warps=8) - return y, s diff --git a/linghe/quant/channel/channel.py b/linghe/quant/channel.py similarity index 88% rename from linghe/quant/channel/channel.py rename to linghe/quant/channel.py index 4968747..01a9571 100644 --- a/linghe/quant/channel/channel.py +++ b/linghe/quant/channel.py @@ -3,6 +3,7 @@ Copyright (c) Ant Financial Service Group and its affiliates. 
""" +from typing import Optional import torch import triton import triton.language as tl @@ -36,6 +37,16 @@ def row_quant_kernel(x_ptr, q_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr, def triton_row_quant(x, round_scale=False): + """ + rowwise quantize x + Args: + x: input x + round_scale: whether round scale to power of 2 + + Returns: + x_q: quantized tensor + x_scale: quantization scale + """ M, N = x.shape BLOCK_SIZE = max([N % x == 0 for x in [512, 1024, 2048, 4096, 8192]]) x_q = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device) @@ -73,9 +84,10 @@ def deprecated_tokenwise_row_quant_kernel(x_ptr, out_ptr, scale_ptr, M, offs += N -def triton_deprecated_tokenwise_row_quant(x, out=None, scale=None, - round_scale=False): - # row-wise read, row-wise write +def triton_deprecated_tokenwise_row_quant(x: torch.Tensor, + out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + round_scale: bool = False): M, N = x.shape device = x.device if out is None: @@ -113,6 +125,16 @@ def tokenwise_row_quant_kernel(x_ptr, out_ptr, scale_ptr, N: tl.constexpr, def triton_tokenwise_row_quant(x, out=None, scale=None, round_scale=False): + """ + rowwise quantize x with power of 2 dim size + Args: + x: input x + round_scale: whether round scale to power of 2 + + Returns: + out: quantized tensor + scale: quantization scale + """ # row-wise read, row-wise write M, N = x.shape device = x.device @@ -169,7 +191,18 @@ def transpose_row_quant_kernel(x_ptr, q_ptr, s_ptr, M, N, H: tl.constexpr, toffs += H -def triton_transpose_row_quant(x, side=0, round_scale=False): +def triton_transpose_row_quant(x, round_scale=False): + """ + transpose x and row quantize x + Args: + x: input x + round_scale: whether round scale to power of 2 + + Returns: + x_q: quantized tensor + x_scale: quantization scale + + """ M, N = x.shape H = 1024 W = 16 @@ -218,7 +251,6 @@ def channel_quant_forward(x, w): def channel_quant_backward(y, w): y_q, y_scale, w_q, w_scale = triton_channel_quant_nn(y, w) - # print(f'{y.shape=} {w.shape=} {y_q.shape=} {y_scale.shape=} {w_q.shape=} {w_scale.shape=}') output = torch._scaled_mm(y_q, w_q.t(), scale_a=y_scale, diff --git a/linghe/quant/channel/__init__.py b/linghe/quant/channel/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/linghe/quant/group.py b/linghe/quant/group.py new file mode 100644 index 0000000..9dec9b8 --- /dev/null +++ b/linghe/quant/group.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def group_quant_kernel(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: tl.constexpr, + K: tl.constexpr, ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + offs = pid * N + tl.arange(0, K * BLOCK_SIZE) + n = tl.cdiv(N, K * BLOCK_SIZE) + soffs = pid * n * K + tl.arange(0, K) + for i in range(n): + x = tl.load(x_ptr + offs).to(tl.float32) + x = tl.reshape(x, (K, BLOCK_SIZE), can_reorder=False) + s = tl.maximum(tl.max(tl.abs(x), 1) / 448.0, 1e-30) + if ROUND: + s = tl.exp2(tl.ceil(tl.log2(s))) + y = x / s[:, None] + y = y.to(y_ptr.dtype.element_ty) + y = tl.reshape(y, (K * BLOCK_SIZE,), can_reorder=False) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + soffs, s) + offs += K * BLOCK_SIZE + soffs += K + + +def triton_group_quant(x, + dtype=torch.float8_e4m3fn, + group_size=128, + round_scale=False): + """ + groupwise quantize x, group is in under rowwise format + Args: + x: input tensor + group_size: group wise + round_scale: whether round scale to power of 2 + + Returns: + y: quantized tensor, float8_e4m3fn + s: quantization scale, float32 + """ + M, N = x.shape + K = 16 + assert N % group_size == 0 and N % (group_size * K) == 0 + assert x.is_contiguous() + + y = torch.empty((M, N), device=x.device, dtype=dtype) + s = torch.empty(M, N // group_size, device=x.device, dtype=torch.float32) + grid = (M,) # noqa + group_quant_kernel[grid](x, + y, + s, + N, + group_size, + K, + round_scale, + num_stages=5, + num_warps=4) + return y, s + diff --git a/linghe/quant/hadamard.py b/linghe/quant/hadamard.py new file mode 100644 index 0000000..1bd77bc --- /dev/null +++ b/linghe/quant/hadamard.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def hadamard_quant_row_kernel( + x_ptr, + hm_ptr, + x_q_ptr, + x_scale_ptr, + M, + N, + BLOCK_SIZE: tl.constexpr, + R: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * R * BLOCK_SIZE + rows = row_start + tl.arange(0, R * BLOCK_SIZE) + mask_rows = rows < M + + hm = tl.load( + hm_ptr + tl.arange(0, BLOCK_SIZE)[:, None] * BLOCK_SIZE + tl.arange(0, + BLOCK_SIZE)[ + None, :]) + + max_val = tl.zeros((R * BLOCK_SIZE,), dtype=tl.float32) + 1.17e-38 + + num_col_blocks = tl.cdiv(N, BLOCK_SIZE) + for col_block in range(num_col_blocks): + col_start = col_block * BLOCK_SIZE + cols = col_start + tl.arange(0, BLOCK_SIZE) + mask_cols = cols < N + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(x, hm) + current_max = tl.max(tl.abs(x_transformed), axis=1) + max_val = tl.maximum(max_val, current_max) + + scale = max_val / 448.0 + tl.store(x_scale_ptr + rows, scale, mask=mask_rows) + s = 448.0 / tl.where(max_val > 0, max_val, 1.0) + + for col_block in range(num_col_blocks): + col_start = col_block * BLOCK_SIZE + cols = col_start + tl.arange(0, BLOCK_SIZE) + mask_cols = cols < N + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(x, hm) + quantized = (x_transformed * s[:, None]).to(x_q_ptr.dtype.element_ty) + tl.store(x_q_ptr + offs, quantized, + mask=mask_rows[:, None] & mask_cols[None, :]) + + +@triton.jit +def hadamard_quant_col_kernel( + x_ptr, + hm_ptr, + xt_q_ptr, + xt_scale_ptr, + M, + N, + BLOCK_SIZE: tl.constexpr, + R: tl.constexpr, +): + pid = tl.program_id(0) + col_start = pid * R * BLOCK_SIZE + cols = col_start + tl.arange(0, R * BLOCK_SIZE) + mask_cols = cols < N + + hm = tl.load( + hm_ptr + tl.arange(0, BLOCK_SIZE)[:, None] * BLOCK_SIZE + tl.arange(0, + BLOCK_SIZE)[ + None, :]) + + max_val = tl.zeros((R * BLOCK_SIZE,), dtype=tl.float32) + 1.17e-38 + + num_row_blocks = tl.cdiv(M, BLOCK_SIZE) + for row_block in range(num_row_blocks): + row_start = row_block * BLOCK_SIZE + rows = row_start + tl.arange(0, BLOCK_SIZE) + mask_rows = rows < M + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(hm, x) + current_max = tl.max(tl.abs(x_transformed), axis=0) + max_val = tl.maximum(max_val, current_max) + + scale = max_val / 448.0 + tl.store(xt_scale_ptr + cols, scale, mask=mask_cols) + s = 448.0 / tl.where(max_val > 0, max_val, 1.0) + + for row_block in range(num_row_blocks): + row_start = row_block * BLOCK_SIZE + rows = row_start + tl.arange(0, BLOCK_SIZE) + mask_rows = rows < M + + offs = rows[:, None] * N + cols[None, :] + x = tl.load(x_ptr + offs, mask=mask_rows[:, None] & mask_cols[None, :], + other=0.0) + x_transformed = tl.dot(hm, x) + quantized = (x_transformed * s[None, :]).to(xt_q_ptr.dtype.element_ty) + quantized_t = tl.trans(quantized) + store_offs = cols[:, None] * M + rows[None, :] + tl.store(xt_q_ptr + store_offs, quantized_t, + mask=mask_cols[:, None] & mask_rows[None, :]) + + +def triton_hadamard_quant(x, hm): + """ + apply hadamard transformation and then quantize transformed tensor + Args: + x: input tensor + hm: hamadard matrix + Returns: + x_q: rowwise quantized tensor of non-transposed x + x_scale: rowwise quantization scale of non-transposed x + xt_q: columnwise quantized tensor 
of transposed x + xt_scale: columnwise quantization scale of transposed x + """ + M, N = x.shape + device = x.device + BLOCK_SIZE = hm.size(0) + R = 1 + x_q = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=device) + xt_q = torch.empty((N, M), dtype=torch.float8_e4m3fn, device=device) + x_scale = torch.empty((M, ), dtype=torch.float32, device=device) + xt_scale = torch.empty((N, ), dtype=torch.float32, device=device) + + grid_row = (triton.cdiv(M, R * BLOCK_SIZE),) + hadamard_quant_row_kernel[grid_row]( + x, + hm, + x_q, + x_scale, + M, + N, + BLOCK_SIZE, + R, + num_stages=6, + num_warps=4 + ) + + grid_col = (triton.cdiv(N, R * BLOCK_SIZE),) + hadamard_quant_col_kernel[grid_col]( + x, + hm, + xt_q, + xt_scale, + M, + N, + BLOCK_SIZE, + R, + num_stages=6, + num_warps=4 + ) + + return x_q, x_scale,xt_q, xt_scale diff --git a/linghe/quant/smooth.py b/linghe/quant/smooth.py new file mode 100644 index 0000000..4844511 --- /dev/null +++ b/linghe/quant/smooth.py @@ -0,0 +1,988 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import torch +import triton +import triton.language as tl + +from linghe.tools.util import round_up +from linghe.utils.transpose import triton_transpose_and_pad + + +@triton.jit +def tokenwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, max_ptr, + M, T, + N: tl.constexpr, + W: tl.constexpr, + EVEN: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + # row-wise read, row-wise write + smooth_scale = tl.load(ss_ptr + tl.arange(0, N))[None, :] + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + if CALIBRATE: + output_maxs = tl.zeros((W, N), dtype=tl.float32) + for i in range(T): + indices = pid * W * T + i * W + tl.arange(0, W) + if EVEN: + x = tl.load(x_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, N)[None, :]).to( + tl.float32) + else: + x = tl.load(x_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, N)[None, :], + mask=indices[:, None] < M).to( + tl.float32) + if CALIBRATE: + output_maxs = tl.maximum(tl.abs(x), output_maxs) + x *= smooth_scale + x_max = tl.max(tl.abs(x), axis=1) + scale = tl.maximum(x_max / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + if EVEN: + tl.store(qs_ptr + pid * W * T + i * W + tl.arange(0, W), scale, ) + else: + tl.store(qs_ptr + pid * W * T + i * W + tl.arange(0, W), scale, + mask=indices < M) + + x /= scale[:, None] + xq = x.to(q_ptr.dtype.element_ty) + if EVEN: + tl.store(q_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, + N)[ + None, :], + xq) + else: + tl.store(q_ptr + pid * W * T * N + i * N * W + tl.arange(0, W)[:, + None] * N + tl.arange( + 0, + N)[ + None, :], + xq, + mask=indices[:, None] < M) + if CALIBRATE: + output_maxs = tl.max(output_maxs, 0) + tl.store(max_ptr + pid * N + tl.arange(0, N), output_maxs) + + +@triton.jit +def blockwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, max_ptr, + M, + N, + H: tl.constexpr, + W: tl.constexpr, + EVEN: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + # row-wise read, row-wise write + offs = pid * W * N + tl.arange(0, W)[:, None] * N + tl.arange(0, H)[None, :] + soffs = tl.arange(0, H) + x_max = tl.zeros((W,), dtype=tl.float32) + n = tl.cdiv(N, H) + for i in range(n): + smooth_scale = tl.load(ss_ptr + soffs) + if EVEN: + x = tl.load(x_ptr + 
offs).to(tl.float32) + else: + x = tl.load(x_ptr + offs, + mask=pid * W + tl.arange(0, W)[:, None] < M).to( + tl.float32) + if CALIBRATE: + output_maxs = tl.max(x.abs(), 0) + tl.store(max_ptr + pid * N + i * H + tl.arange(0, H), output_maxs) + if REVERSE: + x = x * smooth_scale + else: + x = x / smooth_scale + x_max = tl.maximum(tl.max(tl.abs(x), axis=1), x_max) + offs += H + soffs += H + + scale = tl.maximum(x_max / 448, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + tl.store(qs_ptr + pid * W + tl.arange(0, W), scale, + mask=pid * W + tl.arange(0, W) < M) + + s = (1.0 / scale)[:, None] + + offs = pid * W * N + tl.arange(0, W)[:, None] * N + tl.arange(0, H)[None, :] + soffs = tl.arange(0, H) + for i in range(n): + smooth_scale = tl.load(ss_ptr + soffs) + if EVEN: + x = tl.load(x_ptr + offs) + else: + x = tl.load(x_ptr + offs, + mask=pid * W + tl.arange(0, W)[:, None] < M) + + if REVERSE: + xq = (x.to(tl.float32) * smooth_scale * s).to( + q_ptr.dtype.element_ty) + else: + xq = (x.to(tl.float32) / smooth_scale * s).to( + q_ptr.dtype.element_ty) + + if EVEN: + tl.store(q_ptr + offs, xq) + else: + # tl.store(q_ptr+offs, xq, mask=(i*H+tl.arange(0, H)[None,:] 8192 else 4 + EVEN = M % W == 0 + T = triton.cdiv(M, W) + if calibrate: + x_maxs = torch.empty((T, N), device=device, dtype=torch.bfloat16) + else: + x_maxs = None + grid = (T,) + blockwise_smooth_quant_kernel[grid]( + x, + x_q, + smooth_scale, + x_scale, + x_maxs, + M, + N, + H, + W, + EVEN, + reverse, + round_scale, + calibrate, + num_stages=3, + num_warps=4 + ) + if calibrate: + x_maxs = x_maxs.amax(0).float() + + return x_q, x_scale, x_maxs + + +@triton.jit +def subrow_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, + subrow_scales_ptr, + tail_ri, + tail_si, + head_ri, + head_ei, + size, + N, + W: tl.constexpr, + TAIL: tl.constexpr, + HEAD: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr): + if TAIL: + # scale is saved as max/448 + scale = tl.maximum(tl.load(subrow_scales_ptr), 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + # scale only stores in subrow with leading values + + T = tl.cdiv(N - tail_si, W) + for i in range(T): + mask = tail_si + i * W + tl.arange(0, W) < N + if REVERSE: + smooth_scale = tl.load( + ss_ptr + tail_si + i * W + tl.arange(0, W), mask=mask) + else: + smooth_scale = tl.load( + ss_ptr + tail_si + i * W + tl.arange(0, W), other=1e30, + mask=mask) + smooth_scale = 1.0 / smooth_scale + x = tl.load(x_ptr + i * W + tl.arange(0, W), mask=mask).to( + tl.float32) + x *= smooth_scale + x /= scale + xq = tl.minimum(tl.maximum(x, -448), 448) + tl.store(q_ptr + tail_ri * N + tail_si + i * W + tl.arange(0, W), + xq.to(q_ptr.dtype.element_ty), mask=mask) + + if HEAD: + # scale is saved as max/448 + scale = tl.maximum(tl.load(subrow_scales_ptr + 1), 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(qs_ptr + head_ri, scale) + + T = tl.cdiv(head_ei, W) + for i in range(T): + mask = i * W + tl.arange(0, W) < head_ei + if REVERSE: + smooth_scale = tl.load(ss_ptr + i * W + tl.arange(0, W), + mask=mask) + else: + smooth_scale = tl.load(ss_ptr + i * W + tl.arange(0, W), + other=1e30, mask=mask) + smooth_scale = 1.0 / smooth_scale + x = tl.load(x_ptr + size - head_ei + i * W + tl.arange(0, W), + mask=mask).to(tl.float32) + x *= smooth_scale + x /= scale + xq = tl.minimum(tl.maximum(x, -448), 448) + tl.store(q_ptr + head_ri * N + i * W + tl.arange(0, W), + xq.to(q_ptr.dtype.element_ty), mask=mask) + + +def triton_subrow_smooth_quant(x, smooth_scale, x_q, x_scale, 
+ subrow_scales, offset, size, + reverse=False, round_scale=False): + """ + + """ + M, N = x_q.shape + W = 128 + if offset % N == 0: + tail_ri = 0 + tail_si = 0 + TAIL = False + else: + tail_ri = offset // N + tail_si = offset % N + TAIL = True + + if (offset + size) % N == 0: + head_ri = 0 + head_ei = 0 # head_size = head_ei + HEAD = False + else: + head_ri = (offset + size) // N + head_ei = (offset + size) % N + HEAD = True + + grid = (1,) + subrow_smooth_quant_kernel[grid]( + x, + x_q, + smooth_scale, + x_scale, + subrow_scales, + tail_ri, + tail_si, + head_ri, + head_ei, + size, + N, + W, + TAIL, + HEAD, + reverse, + round_scale, + num_stages=3, + num_warps=1 + ) + + +@triton.jit +def depracated_tokenwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, + qs_ptr, M, W, + N: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + # row-wise read, row-wise write + smooth_scale = tl.load(ss_ptr + tl.arange(0, N)) + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + for i in range(W): + x = tl.load(x_ptr + pid * W * N + i * N + tl.arange(0, N), + mask=pid * W + i < M).to(tl.float32) + x *= smooth_scale + x_max = tl.maximum(tl.max(tl.abs(x)), 1e-30) + + scale = x_max / 448.0 + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(qs_ptr + pid * W + i, scale, mask=pid * W + i < M) + + x /= scale + xq = x.to(q_ptr.dtype.element_ty) + tl.store(q_ptr + pid * W * N + i * N + tl.arange(0, N), xq, + mask=pid * W + i < M) + + +def triton_depracated_tokenwise_smooth_quant(x, smooth_scale, x_q=None, + x_scale=None, reverse=False, + round_scale=False): + """ + + """ + # row-wise read, row-wise write + M, N = x.shape + device = x.device + if x_q is None: + x_q = torch.empty((M, N), device=device, dtype=torch.float8_e4m3fn) + if x_scale is None: + x_scale = torch.empty((M,), device=device, dtype=torch.float32) + sm = torch.cuda.get_device_properties(device).multi_processor_count + W = triton.cdiv(M, sm) + grid = (sm,) + depracated_tokenwise_smooth_quant_kernel[grid]( + x, + x_q, + smooth_scale, + x_scale, + M, + W, + N, + reverse, + round_scale, + num_stages=3, + num_warps=8 + ) + return x_q, x_scale + + +@triton.jit +def batch_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, xm_ptr, count_ptr, + accum_ptr, T, N: tl.constexpr, + REVERSE: tl.constexpr, ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + + i_expert = pid // T + i_batch = pid % T + + # row-wise read, row-wise write + smooth_scale = tl.load(ss_ptr + i_expert * N + tl.arange(0, N)) + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + if CALIBRATE: + x_maxs = tl.zeros((N,), dtype=tl.float32) + + count = tl.load(count_ptr + i_expert) + ei = tl.load(accum_ptr + i_expert) + si = ei - count + + n = tl.cdiv(count, T) # samples for each task + for i in range(i_batch * n, min((i_batch + 1) * n, count)): + x = tl.load(x_ptr + si * N + i * N + tl.arange(0, N)).to(tl.float32) + if CALIBRATE: + x_maxs = tl.maximum(x_maxs, x.abs()) + x *= smooth_scale + scale = tl.maximum(tl.max(tl.abs(x)) / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + tl.store(qs_ptr + si + i, scale) + + s = 1.0 / scale + x *= s + xq = x.to(q_ptr.dtype.element_ty) + tl.store(q_ptr + si * N + i * N + tl.arange(0, N), xq) + + if CALIBRATE: + tl.store(xm_ptr + pid * N + tl.arange(0, N), x_maxs) + + +""" +select and smooth and quant +x: [bs, dim] +smooth_scales: [n_experts, dim] +token_count_per_expert: [n_experts] +x_q: [bs, dim] +x_scale: [bs] +""" + + +def 
triton_batch_smooth_quant(x, smooth_scales, token_count_per_expert, + x_q=None, x_scale=None, x_maxs=None, + reverse=False, round_scale=False, + calibrate=False): + """ + + """ + M, N = x.shape + device = x.device + n_expert = token_count_per_expert.shape[0] + assert 128 % n_expert == 0 + if x_q is None: + x_q = torch.empty((M, N), device=device, dtype=torch.float8_e4m3fn) + if x_scale is None: + x_scale = torch.empty((M,), device=device, dtype=torch.float32) + accum_token_count = torch.cumsum(token_count_per_expert, 0) + T = 128 // n_expert + if calibrate and x_maxs is None: + x_maxs = torch.empty((128, N), device=device, dtype=torch.float32) + + grid = (128,) + batch_smooth_quant_kernel[grid]( + x, + x_q, + smooth_scales, + x_scale, + x_maxs, + token_count_per_expert, + accum_token_count, + T, N, + reverse, + round_scale, + calibrate, + num_stages=3, + num_warps=8 + ) + if calibrate: + x_maxs = x_maxs.view(n_expert, T, N).amax(1) + return x_q, x_scale, x_maxs + + +@triton.jit +def batch_pad_transpose_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, + count_ptr, + accum_ptr, + N, + H: tl.constexpr, + W: tl.constexpr, + E: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr): + eid = tl.program_id(axis=0) + bid = tl.program_id(axis=1) + + count = tl.load(count_ptr + eid) + ei = tl.load(accum_ptr + eid) + si = ei - count + round_count = tl.cdiv(count, 32) * 32 + + counts = tl.load(count_ptr + tl.arange(0, E)) + n_blocks = tl.cdiv(counts, 128) + bias = tl.sum(tl.where(tl.arange(0, E) < eid, n_blocks, 0)) + + n = tl.cdiv(count, H) + maxs = tl.zeros((H, W), dtype=tl.float32) + for i in range(n): + # col-wise read, row-wise write + indices = i * H + tl.arange(0, H) + smooth_scale = tl.load(ss_ptr + indices, mask=indices < count) + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + x = tl.load(x_ptr + si * N + i * H * N + bid * W + tl.arange(0, H)[:, + None] + tl.arange(0, + W)[ + None, :], + mask=indices[:, None] < count).to(tl.float32) + x *= smooth_scale[:, None] + maxs = tl.maximum(maxs, tl.abs(x)) + + maxs = tl.max(maxs, 0) + scale = tl.maximum(tl.max(maxs, 0) / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(qs_ptr + eid * N + bid * W + tl.arange(0, W), scale) + s = 1.0 / scale + + for i in range(n): + # col-wise read, row-wise write + indices = i * H + tl.arange(0, H) + smooth_scale = tl.load(ss_ptr + indices, mask=indices < count) + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + x = tl.load(x_ptr + si * N + i * H * N + bid * W + tl.arange(0, H)[:, + None] + tl.arange(0, + W)[ + None, :], + mask=indices[:, None] < count).to(tl.float32) + x *= smooth_scale[:, None] + x *= s + xq = tl.trans(x.to(q_ptr.dtype.element_ty)) + tl.store( + q_ptr + bias * N + bid * W * round_count + i * H + tl.arange(0, W)[ + :, + None] + tl.arange( + 0, H)[None, :], xq, mask=indices[None, :] < round_count) + + +""" +used in silu backward +pad to multiple of 32 and transpose and smooth quant +x: [sum(token_per_expert), dim] +smooth_scales: [sum(token_per_expert)] +token_count_per_expert: [n_experts] +splits: list of token_count_per_expert +x_q: [sum(roundup(token_per_expert)) * dim] +x_scale: [n_experts, dim] +""" + + +def triton_batch_pad_transpose_smooth_quant(x, + smooth_scales, + token_count_per_expert, + splits, + x_q=None, x_scale=None, x_maxs=None, + reverse=False, round_scale=False): + """ + + """ + M, N = x.shape + device = x.device + n_expert = token_count_per_expert.shape[0] + round_splits = [(x + 31) // 32 * 32 for x in splits] + round_size = 
sum(round_splits) + if x_q is None: + x_q = torch.empty((round_size, N), device=device, + dtype=torch.float8_e4m3fn) + if x_scale is None: + x_scale = torch.empty((n_expert, N), device=device, dtype=torch.float32) + accum_token_count = torch.cumsum(token_count_per_expert, 0) + H = 128 + W = 32 + grid = (n_expert, N // W) + batch_pad_transpose_smooth_quant_kernel[grid]( + x, + x_q, + smooth_scales, + x_scale, + token_count_per_expert, + accum_token_count, + N, + H, + W, + n_expert, + reverse, + round_scale, + num_stages=3, + num_warps=8 + ) + return x_q, x_scale + + +@triton.jit +def transpose_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, M, N, P, + H: tl.constexpr, W: tl.constexpr, + EVEN: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + # col-wise read, row-wise write + offs = pid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, W)[None, :] + soffs = tl.arange(0, H) + x_max = tl.zeros((W,), dtype=tl.float32) + m = tl.cdiv(P, H) + for i in range(m): + if EVEN: + x = tl.load(x_ptr + offs) + smooth_scale = tl.load(ss_ptr + soffs)[:, None] + else: + x = tl.load(x_ptr + offs, + mask=(i * H + tl.arange(0, H)[:, None] < M) & ( + pid * W + tl.arange(0, W)[None, :] < N)) + other = 0.0 if REVERSE else 1e30 + smooth_scale = tl.load(ss_ptr + soffs, mask=soffs < M, other=other)[ + :, None] + if REVERSE: + x = x * smooth_scale + else: + x = x / smooth_scale + x_max = tl.maximum(tl.max(tl.abs(x), axis=0), x_max) + offs += H * N + soffs += H + + scale = tl.maximum(x_max / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + if EVEN: + tl.store(qs_ptr + pid * W + tl.arange(0, W), scale) + else: + tl.store(qs_ptr + pid * W + tl.arange(0, W), scale, + mask=pid * W + tl.arange(0, W) < N) + + s = (1.0 / scale)[None, :] + offs = pid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, W)[None, :] + soffs = tl.arange(0, H) + toffs = pid * W * P + tl.arange(0, W)[:, None] * P + tl.arange(0, H)[None, + :] + for i in range(m): + if EVEN: + x = tl.load(x_ptr + offs).to(tl.float32) + smooth_scale = tl.load(ss_ptr + soffs)[:, None] + else: + x = tl.load(x_ptr + offs, + mask=(i * H + tl.arange(0, H)[:, None] < M)).to( + tl.float32) + other = 0.0 if REVERSE else 1e30 + smooth_scale = tl.load(ss_ptr + soffs, mask=soffs < M, other=other)[ + :, None] + + if REVERSE: + x = (x * smooth_scale * s).to(q_ptr.dtype.element_ty) + else: + x = (x / smooth_scale * s).to(q_ptr.dtype.element_ty) + if EVEN: + tl.store(q_ptr + toffs, tl.trans(x)) + else: + # mask with P instead of M + tl.store(q_ptr + toffs, tl.trans(x), + mask=(i * H + tl.arange(0, H)[None, :] < P)) + offs += H * N + toffs += H + soffs += H + + +def triton_transpose_smooth_quant(x, + smooth_scale, + reverse=False, + pad=False, + round_scale=False): + # col-wise read, row-wise write + # M should be padded if M % 32 != 0 + """ + + """ + M, N = x.shape + device = x.device + P = (M + 31) // 32 * 32 if pad else M + x_q = torch.empty((N, P), device=device, dtype=torch.float8_e4m3fn) + x_scale = torch.empty((N,), device=device, dtype=torch.float32) + H = 1024 + W = 16 # if N >= 4096 else 16 + assert N % W == 0 + EVEN = P % H == 0 and M == P + + grid = (triton.cdiv(N, W),) + transpose_smooth_quant_kernel[grid]( + x, + x_q, + smooth_scale, + x_scale, + M, + N, + P, + H, + W, + EVEN, + reverse, + round_scale, + num_stages=3, + num_warps=4 if N >= 8192 else 4 + ) + return x_q, x_scale + + +@triton.jit +def transpose_rescale_smooth_quant_kernel(x_ptr, q_ptr, + org_smooth_scale_ptr, + org_quant_scale_ptr, + 
transpose_smooth_scale_ptr, + transpose_quant_scale_ptr, M, + N, P, H: tl.constexpr, + W: tl.constexpr, + EVEN: tl.constexpr, + ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + # col-wise read, row-wise write + offs = pid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, W)[None, :] + soffs = tl.arange(0, H) + x_max = tl.zeros((W,), dtype=tl.float32) + org_smooth_scale = tl.load( + org_smooth_scale_ptr + pid * W + tl.arange(0, W))[None, :] + + m = tl.cdiv(P, H) + for i in range(m): + if EVEN: + x = tl.load(x_ptr + offs).to(tl.float32) + org_quant_scale = tl.load(org_quant_scale_ptr + soffs)[:, None] + transpose_smooth_scale = tl.load( + transpose_smooth_scale_ptr + soffs)[:, None] + else: + x = tl.load(x_ptr + offs, + mask=(i * H + tl.arange(0, H)[:, None] < M)).to( + tl.float32) + org_quant_scale = tl.load(org_quant_scale_ptr + soffs, + mask=soffs < M, other=0.0)[:, None] + transpose_smooth_scale = tl.load(transpose_smooth_scale_ptr + soffs, + mask=soffs < M, other=0.0)[:, None] + + x = x / org_smooth_scale * (org_quant_scale * transpose_smooth_scale) + x_max = tl.maximum(tl.max(tl.abs(x), axis=0), x_max) + offs += H * N + soffs += H + + scale = tl.maximum(x_max / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + tl.store(transpose_quant_scale_ptr + pid * W + tl.arange(0, W), scale) + + s = (1.0 / scale)[None, :] + + offs = pid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, W)[None, :] + soffs = tl.arange(0, H) + toffs = pid * W * P + tl.arange(0, W)[:, None] * P + tl.arange(0, H)[None, + :] + for i in range(m): + + if EVEN: + x = tl.load(x_ptr + offs).to(tl.float32) + org_quant_scale = tl.load(org_quant_scale_ptr + soffs)[:, None] + transpose_smooth_scale = tl.load( + transpose_smooth_scale_ptr + soffs)[:, None] + else: + x = tl.load(x_ptr + offs, + mask=(i * H + tl.arange(0, H)[:, None] < M) & ( + pid * W + tl.arange(0, W)[None, :] < N)).to( + tl.float32) + org_quant_scale = tl.load(org_quant_scale_ptr + soffs, + mask=soffs < M, other=0.0)[:, None] + transpose_smooth_scale = tl.load(transpose_smooth_scale_ptr + soffs, + mask=soffs < M, other=0.0)[:, None] + + x = x * s / org_smooth_scale * ( + org_quant_scale * transpose_smooth_scale) + x = tl.trans(x.to(q_ptr.dtype.element_ty)) + if EVEN: + tl.store(q_ptr + toffs, x) + else: + tl.store(q_ptr + toffs, x, + mask=(i * H + tl.arange(0, H)[None, :] < P)) + offs += H * N + toffs += H + soffs += H + + +""" +x_q is colwise smooth and rowwise quant +org_smooth_scale and transpose_smooth_scale is reversed +smooth scale and quant scale should be power of 2 +step: dequant x_q -> apply smooth scale -> quant -> transpose -> pad +implement: x_q/org_smooth_scale*(org_quant_scale*smooth_scale) -> colwise quant and transpose +""" + + +def triton_transpose_rescale_smooth_quant(x_q, org_smooth_scale, + org_quant_scale, + transpose_smooth_scale, + reverse=True, + pad=False, + round_scale=False): + """ + + """ + assert reverse + M, N = x_q.shape + device = x_q.device + P = round_up(M, b=32) if pad else M + xt_q = torch.empty((N, P), device=device, dtype=torch.float8_e4m3fn) + x_scale = torch.empty((N,), device=device, dtype=torch.float32) + H = 256 + W = 16 + assert N % W == 0 + EVEN = P == M and M % H == 0 + + grid = (triton.cdiv(N, W),) + transpose_rescale_smooth_quant_kernel[grid]( + x_q, + xt_q, + org_smooth_scale, + org_quant_scale, + transpose_smooth_scale, + x_scale, + M, N, P, + H, W, + EVEN, + round_scale, + num_stages=4, + num_warps=8 + ) + + return xt_q, x_scale + + + +""" +megatron fp8 training steps: +step 0: init 
w smooth scale w_smooth +step 1: smooth and quant w after w is updated by optimizer +step 2: in forward step, columnwise smooth x and rowwise quant x, calc y=x@w; + meanwhile, record the columnwise max of x, it is used to update w_smooth +step 3: in dgrad step, columnwise smooth y and rowwise quant y, transpose x, calc dx=y@wT +step 4: in wgrad step, dequant then smooth an then quant y_q to get yt_q, calc dw=yT@x + +alternative (it's not suitable for fp8 combine): +step 4: in wgrad step, rowwise smooth y and columnwise quant y and transpose to get yt_q, calc dw=yT@x + +""" + +""" +divide x by smooth_scale and row-wise quantization +smooth scale is updated by square root of x's column-wise maxs, and set in weight's x_maxs attr + +transpose: transpose quantized x for wgrad +pad: # pad M to be multiplier of 32, including quant scales and transposed x + +""" + + +# y = x @ w +# dx = y @ wT +# dwT = yT @ x +def triton_smooth_quant_input(x, smooth_scale, x_q=None, x_scale=None, xt_q=None, + transpose=True, pad=True, round_scale=False): + """ + + """ + x_q, x_scale, x_maxs = triton_smooth_quant(x, smooth_scale, x_q=x_q, + x_scale=x_scale, reverse=False, + round_scale=round_scale) + + if transpose: + xt_q = triton_transpose_and_pad(x_q, out=xt_q, pad=pad) + else: + xt_q = None + xt_scale = smooth_scale + + return x_q, xt_q, x_scale, xt_scale + + +# y = x @ w +# dx = y @ wT +# dwT = yT @ x +def triton_smooth_quant_gradient(y, + smooth_scale, + transpose_smooth_scale, + reverse=True, + transpose=True, + pad=True, + round_scale=False): + """ + + """ + assert reverse, ("args `smooth_scale` and/or `transpose_smooth_scale` " + "must be in reciprocal format in triton_smooth_quant_grad") + y_q, y_scale, _ = triton_smooth_quant(y, smooth_scale, reverse=True, + round_scale=round_scale) + if transpose: + yt_q, yt_scale = triton_transpose_smooth_quant(y, + transpose_smooth_scale, + reverse=True, + pad=pad, + round_scale=round_scale) + else: + yt_q, yt_scale = None, None + + return y_q, yt_q, y_scale, yt_scale + + +def triton_smooth_quant_weight(w, + smooth_scale, + w_q, + quant_scale, + subrow_scales, offset=0, + round_scale=False): + """ + + """ + assert w.ndim == 1 + assert w_q.size(1) == smooth_scale.size(0) + + size = w.numel() + M, N = w_q.shape + + if size == M * N: + triton_smooth_quant(w.view(M, N), smooth_scale, x_q=w_q, + x_scale=quant_scale, + round_scale=round_scale) + elif offset % N == 0 and size % N == 0: + n_row = size // N + row_id = offset // N + w_q_slice = w_q[row_id:row_id + n_row] + quant_scale_slice = quant_scale[row_id:row_id + n_row] + triton_smooth_quant(w.view(n_row,N), smooth_scale, x_q=w_q_slice, + x_scale=quant_scale_slice, + round_scale=round_scale) + else: + row_si = (offset - 1)//N + 1 + row_ei = (offset + size) // N + col_si = offset % N + col_ei = (offset + size ) % N + n_row = row_ei - row_si + mw_offset = 0 if col_si == 0 else N - col_si + w_q_slice = w_q[row_si:row_ei] + quant_scale_slice = quant_scale[row_si:row_ei] + w_slice = w[mw_offset:mw_offset+n_row*N].view(n_row,N) + triton_smooth_quant(w_slice, + smooth_scale, + x_q=w_q_slice, + x_scale=quant_scale_slice, + round_scale=round_scale) + + # subrow scale is writed by the row with leading master weights + if col_si > 0 or col_ei > 0: + triton_subrow_smooth_quant(w, + smooth_scale, + w_q, + quant_scale, + subrow_scales, + offset, + size, + reverse=False, + round_scale=round_scale) + diff --git a/linghe/tools/util.py b/linghe/tools/util.py index c2486a0..a97768e 100644 --- a/linghe/tools/util.py +++ 
b/linghe/tools/util.py @@ -4,7 +4,6 @@ """ import math - import torch @@ -80,6 +79,45 @@ def torch_block_quant(w, B=128, dtype=torch.float8_e4m3fn, round_scale=False): return wq, scale +def torch_smooth_quant(x, smooth_scale, reverse=False, round_scale=False): + x = x.float() + x_maxs = x.abs().amax(0) + if reverse: + x_smooth = x * smooth_scale + else: + x_smooth = x / torch.maximum(smooth_scale, + 1e-30 * torch.ones_like(smooth_scale)) + scale = x_smooth.abs().amax(1) / 448 + scale = torch.maximum(scale, 1e-30 * torch.ones_like(scale)) + if round_scale: + scale = torch.exp2(torch.ceil(torch.log2(scale))) + x_q = (x_smooth / scale[:, None]).to(torch.float8_e4m3fn) + return x_q, scale, x_maxs + + +def torch_batch_smooth_quant(xs, smooth_scales, indices, token_count_per_expert, + reverse=False, round_scale=False): + q_refs = [] + scale_refs = [] + s = 0 + for i, c in enumerate(token_count_per_expert): + idx = indices[s:s + c] + y_slice = xs[idx] + if reverse: + y_smooth = y_slice * smooth_scales[i] + else: + y_smooth = y_slice / smooth_scales[i] + scale = y_smooth.abs().amax(1) / 448 + if round_scale: + scale = torch.exp2(torch.ceil(torch.log2(scale))) + q_refs.append((y_smooth / scale[:, None]).to(torch.float8_e4m3fn)) + scale_refs.append(scale) + s += c + q_ref = torch.cat(q_refs, 0) + scale_ref = torch.cat(scale_refs, 0) + return q_ref, scale_ref + + def torch_make_indices(logits, topk=8, bias=-0.01): M, n_experts = logits.shape device = logits.device @@ -105,6 +143,195 @@ def torch_make_indices(logits, topk=8, bias=-0.01): return probs, route_map, token_count_per_expert, indices, row_id_map +# quant with scaling to 448 +def torch_duplex_smooth_tensor_quant(x, w, dtype): + # w:[bs, in] w:[out, in] + x = x.clone() + w = w.clone() + fmax = torch.finfo(dtype).max + x_max = torch.max(torch.abs(x).float(), dim=0, keepdim=True)[0] + w_max = torch.max(torch.abs(w).float(), dim=0, keepdim=True)[0] + scale = (x_max / w_max) ** 0.5 + x_max_ = x_max / scale + w_max_ = w_max * scale + x_scale = x_max_ / fmax + w_scale = w_max_ / fmax + rescale = fmax / torch.maximum(x_max_.max(), w_max_.max()) + x_q = (x * (rescale / scale).to(x.dtype)).to(dtype) + w_q = (w * (scale * rescale).to(x.dtype)).to(dtype) + + return x_q, w_q, scale, rescale + + +def torch_duplex_smooth_quant(x, w, dtype=torch.float8_e4m3fn): + # w:[bs, in] w:[out, in] + x = x.clone() + w = w.clone() + fmax = torch.finfo(dtype).max + x_max = torch.max(torch.abs(x).float(), dim=0, keepdim=True)[0] + w_max = torch.max(torch.abs(w).float(), dim=0, keepdim=True)[0] + maxs = (x_max * w_max) ** 0.5 + x_scale = x_max / maxs + w_scale = w_max / maxs # reciprocal of x_scale + x_smooth = x / x_scale + w_smooth = w / w_scale + x_max = torch.max(torch.abs(x_smooth).float(), dim=1, keepdim=True)[0] + w_max = torch.max(torch.abs(w_smooth).float(), dim=1, keepdim=True)[0] + x_scale = x_max / fmax + w_scale = w_max / fmax + x_q = (x_smooth * (1.0 / x_scale).to(x.dtype)).to(dtype) + w_q = (w_smooth * (1.0 / w_scale).to(x.dtype)).to(dtype) + + return x_q, w_q, x_scale, w_scale + + +def torch_outlier_quant(x, w, dtype): + x = x.clone() + w = w.clone() + fmax = torch.finfo(dtype).max + max_val, max_idx = torch.topk(x.abs().float().max(dim=0)[0], 5) + # print(max_idx) + x_outlier = x[:, max_idx[:4]] + x[:, max_idx[:4]] = 0.0 + x_scale = max_val[-1] / fmax + xq = (x / x_scale.to(x.dtype)).to(dtype) + w_max = w.abs().float().max() + w_scale = w_max / fmax + wq = (w / w_scale.to(x.dtype)).to(dtype) + return xq, wq, x_scale, w_scale, max_idx[:4], x_outlier + + 
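# Example (illustrative, not part of this PR): the smoothing identity that the
# torch_duplex_smooth_quant helpers above rely on -- dividing x and multiplying
# w by the same per-channel factor s leaves the product unchanged:
#     (x / s) @ (w * s).t() == x @ w.t()
# s = sqrt(max|x| / max|w|) balances the two dynamic ranges before fp8 rounding.
import torch

x = torch.randn(64, 128)
w = torch.randn(32, 128)
s = (x.abs().amax(0) / w.abs().amax(0)).sqrt()  # per-input-channel factor [128]
out_ref = x @ w.t()
out_smooth = (x / s) @ (w * s).t()
assert torch.allclose(out_ref, out_smooth, atol=1e-3)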
+def make_hadamard_matrix(n, device='cuda:0', dtype=torch.bfloat16, norm=False): + assert 2 ** int(math.log2(n)) == n + m2 = torch.tensor([[1, 1], [1, -1]], device='cpu', dtype=torch.float32) + m = m2 + for i in range(int(math.log2(n)) - 1): + m = torch.kron(m, m2) + if norm: + m = m / n ** 0.5 + return m.to(dtype=dtype, device=device) + + +def torch_hadamard_transform(x, hm, side='right'): + assert side in ('right', 'left') + x = x.clone() + hm = hm.clone() + M, K = x.shape + B = hm.size(0) + xp = torch.reshape(x, (M // B, B, K // B, B)).permute(0, 2, 1, + 3).contiguous() + if side == 'right': + xp = xp @ hm + else: + xp = hm @ xp + xp = xp.permute(0, 2, 1, 3) + xp = torch.reshape(xp, (M, K)) + return xp + + +# token-wise and channel-wise +def torch_channel_quant_f_and_b(x, w, y): + M, K = x.shape + N, K = w.shape + M, N = y.shape + x_scale = x.abs().float().amax(dim=1, keepdim=True) / 448.0 # [M,1] + w_scale = w.abs().float().amax(dim=1, keepdim=True) / 448.0 # [N,1] + xq = (x / x_scale).to(torch.float8_e4m3fn) + wq = (w / w_scale).to(torch.float8_e4m3fn) + o = torch._scaled_mm(xq, + wq.t(), + scale_a=x_scale.view(-1, 1), + scale_b=w_scale.view(1, -1), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # dx = y @ wT + # absort w quant scale to y + ys = y * w_scale.view(1, N) + y_scale = ys.abs().float().amax(dim=1, keepdim=True) / 448.0 + 1e-9 + yq = (ys / y_scale).to(torch.float8_e4m3fn) + w_dummy_scale = torch.ones((1, K), dtype=torch.float32, device=x.device) + dx = torch._scaled_mm(yq, + wq.t().contiguous().t(), + scale_a=y_scale, + scale_b=w_dummy_scale, + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # dw = yT@x + yt = y.t().contiguous() + yts = yt * x_scale.view(1, M) + yt_scale = yts.abs().float().amax(dim=1, keepdim=True) / 448.0 + 1e-9 + ytq = (yts / yt_scale).to(torch.float8_e4m3fn) + dw = torch._scaled_mm(ytq, + xq.t().contiguous().t(), + scale_a=yt_scale.view(-1, 1), + scale_b=w_dummy_scale, + out_dtype=torch.bfloat16, + use_fast_accum=True) + return xq, wq, yq, ytq, o, dx, dw + + +# smooth and token-wise/channel-wise +def torch_reuse_smooth_quant_f_and_b(x, w, y): + x = x.clone() + w = w.clone() + y = y.clone() + M, K = x.shape + N, K = w.shape + M, N = y.shape + x_smooth_max = torch.amax(torch.abs(x).float(), dim=0, keepdim=True) + w_smooth_max = torch.amax(torch.abs(w).float(), dim=0, keepdim=True) + maxs = (x_smooth_max * w_smooth_max) ** 0.5 + x_smooth_scale = x_smooth_max / maxs # [K, 1] + w_smooth_scale = w_smooth_max / maxs # [K, 1] reciprocal of x_scale + x_smooth = x / x_smooth_scale + w_smooth = w / w_smooth_scale + + x_quant_max = torch.amax(torch.abs(x_smooth).float(), dim=1, keepdim=True) + w_quant_max = torch.amax(torch.abs(w_smooth).float(), dim=1, keepdim=True) + + x_quant_scale = x_quant_max / 448.0 # [M, 1] + w_quant_scale = w_quant_max / 448.0 # [N, 1] + xq = (x_smooth / x_quant_scale).to(torch.float8_e4m3fn) + wq = (w_smooth / w_quant_scale).to(torch.float8_e4m3fn) + + o = torch._scaled_mm(xq, + wq.t(), + scale_a=x_quant_scale.view(-1, 1), + scale_b=w_quant_scale.view(1, -1), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # print(f'{x_smooth_scale=} {x_quant_scale[:,0]=} {w_quant_scale=}') + + # dx = y @ wT + # absort w quant scale to y + ys = y * w_quant_scale.view(1, N) + y_scale = ys.abs().float().amax(dim=1, keepdim=True) / 448.0 + 1e-9 + yq = (ys / y_scale).to(torch.float8_e4m3fn) + dx = torch._scaled_mm(yq, + wq.t().contiguous().t(), + scale_a=y_scale, + scale_b=w_smooth_scale.view(1, -1), + out_dtype=torch.bfloat16, + 
use_fast_accum=True) + + # dw = yT@x + yt = y.t().contiguous() # [N, M] + yts = yt * x_quant_scale.view(1, M) + yt_scale = yts.abs().amax(dim=1, keepdim=True) / 448.0 + 1e-9 + ytq = (yts / yt_scale).to(torch.float8_e4m3fn) + dw = torch._scaled_mm(ytq, + xq.t().contiguous().t(), + scale_a=yt_scale.view(-1, 1), + scale_b=x_smooth_scale.view(1, -1), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + return xq, wq, yq, ytq, o, dx, dw + + def fp16_forward(x, w): return x @ w @@ -217,3 +444,45 @@ def read_and_tile(filename, tile=True): f'y.max={y.abs().max().item():.3f} y.mean={y.abs().mean().item():.3f}') return x, w, y + + +def torch_fp16_vector_scaled_mm(x, weight, x_scale, weight_scale): + output = torch._scaled_mm(x, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=torch.bfloat16, + use_fast_accum=True) + return output + + +def torch_fp32_vector_scaled_mm(x, weight, x_scale, weight_scale, ones, + out=None): + output = torch._scaled_mm(x, + weight, + scale_a=ones, + scale_b=ones, + out_dtype=torch.float32, + use_fast_accum=True, + out=out) + return output * x_scale * weight_scale + + +def torch_fp16_scaler_scaled_mm(x, weight, x_scale, weight_scale): + output = torch._scaled_mm(x, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=torch.bfloat16, + use_fast_accum=True) + return output + + +def torch_fp32_scaler_scaled_mm(x, weight, x_scale, weight_scale): + output = torch._scaled_mm(x, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=torch.float32, + use_fast_accum=True) + return output diff --git a/linghe/utils/add.py b/linghe/utils/add.py index 907c852..868e1f0 100644 --- a/linghe/utils/add.py +++ b/linghe/utils/add.py @@ -3,6 +3,8 @@ Copyright (c) Ant Financial Service Group and its affiliates. """ +import torch +from typing import Iterable, Optional, Tuple import triton import triton.language as tl @@ -43,7 +45,17 @@ def inplace_add_kernel(x_ptr, y_ptr, M, N, H: tl.constexpr, W: tl.constexpr, rid * H + tl.arange(0, H)[None, :] < M)) -def triton_inplace_add(x, y, accum=True): +def triton_inplace_add(x: torch.Tensor, y: torch.Tensor, accum : bool = True): + """ + inplace add y to x + Args: + x: Tensor + y: Tensor + accum: x += y if accum=True else x.copy_(y) + + Returns: + updated x + """ N = x.shape[-1] M = x.numel() // N # M, N = x.shape @@ -64,63 +76,3 @@ def triton_inplace_add(x, y, accum=True): num_warps=num_warps ) return x - - -@triton.jit -def block_add_kernel(x_ptr, y_ptr, M, N, H: tl.constexpr, W: tl.constexpr, - EVEN: tl.constexpr, ACCUM: tl.constexpr): - rid = tl.program_id(axis=0) - cid = tl.program_id(axis=1) - offs = rid * H * N + cid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, - W)[ - None, :] - if ACCUM: - if EVEN: - x = tl.load(x_ptr + offs) - y = tl.load(y_ptr + offs).to(tl.float32) - tl.store(x_ptr + offs, x + y) - else: - x = tl.load(x_ptr + offs, - mask=(cid * W + tl.arange(0, W)[None, :] < N) & ( - rid * H + tl.arange(0, H)[:, None] < M)) - y = tl.load(y_ptr + offs, - mask=(cid * W + tl.arange(0, W)[None, :] < N) & ( - rid * H + tl.arange(0, H)[:, None] < M)) - tl.store(x_ptr + offs, x + y, - mask=(cid * W + tl.arange(0, W)[:, None] < N) & ( - rid * H + tl.arange(0, H)[None, :] < M)) - else: - if EVEN: - y = tl.load(y_ptr + offs).to(tl.float32) - tl.store(x_ptr + offs, y) - else: - y = tl.load(y_ptr + offs, - mask=(cid * W + tl.arange(0, W)[None, :] < N) & ( - rid * H + tl.arange(0, H)[:, None] < M)) - tl.store(x_ptr + offs, y, - mask=(cid * W + tl.arange(0, W)[:, None] < N) & ( - rid * H + tl.arange(0, H)[None, 
:] < M)) - - -def triton_block_add(x, y, accum=True): - shape = x.shape[-1] - N = shape - M = x.numel() // N - # M, N = x.shape - H = 128 - W = 128 - EVEN = M % H == 0 and N % W == 0 - num_stages = 2 - num_warps = 8 - - grid = (triton.cdiv(M, H), triton.cdiv(N, W)) - block_add_kernel[grid]( - x, y, - M, N, - H, W, - EVEN, - accum, - num_stages=num_stages, - num_warps=num_warps - ) - return x diff --git a/linghe/utils/dot.py b/linghe/utils/dot.py index 6ac1ead..f57acab 100644 --- a/linghe/utils/dot.py +++ b/linghe/utils/dot.py @@ -18,14 +18,24 @@ def dot_kernel(x_ptr, y_ptr, sum_ptr, M, N, H: tl.constexpr, W: tl.constexpr): sums = tl.zeros((W,), dtype=tl.float32) for i in range(n): x = tl.load(x_ptr + offs).to(tl.float32) - q = tl.load(y_ptr + offs).to(tl.float32) - sums += tl.sum(x * q, axis=1) + y = tl.load(y_ptr + offs).to(tl.float32) + sums += tl.sum(x * y, axis=1) offs += H tl.store(sum_ptr + pid * W + tl.arange(0, W), sums) def triton_dot(x, y): + """ + vector dot multiply, output = sum(x*y, 1), + it is used to calculate gradient of router weight + Args: + x: + y: + + Returns: + output of sum(x*y, 1) + """ M, N = x.shape H = 128 W = 16 @@ -45,52 +55,3 @@ def triton_dot(x, y): ) return s - -@triton.jit -def mix_precise_dot_kernel(x_ptr, q_ptr, sum_ptr, smooth_scale_ptr, - quant_scale_ptr, M, N, H: tl.constexpr, - W: tl.constexpr): - # rowwise read, rowwise write - pid = tl.program_id(axis=0) - offs = pid * W * N + tl.arange(0, W)[:, None] * N + tl.arange(0, H)[None, :] - soffs = tl.arange(0, H) - quant_scale = tl.load(quant_scale_ptr + pid * W + tl.arange(0, W)) - - n = tl.cdiv(N, H) - sums = tl.zeros((W,), dtype=tl.float32) - for i in range(n): - x = tl.load(x_ptr + offs) - q = tl.load(q_ptr + offs) - smooth_scale = tl.load(smooth_scale_ptr + soffs)[None, :] - q = q.to(tl.float32) * smooth_scale - x = x.to(tl.float32) - sums += tl.sum(x * q, axis=1) * quant_scale - offs += H - soffs += H - - tl.store(sum_ptr + pid * W + tl.arange(0, W), sums) - - -# q should be dequant -def triton_mix_precise_dot(x, q, smooth_scale, quant_scale, reverse=False): - assert reverse - M, N = x.shape - device = x.device - s = torch.empty((M,), device=device, dtype=x.dtype) - - H = 128 - W = 16 - num_stages = 5 - num_warps = 8 - - grid = (triton.cdiv(M, W),) - mix_precise_dot_kernel[grid]( - x, q, s, - smooth_scale, - quant_scale, - M, N, - H, W, - num_stages=num_stages, - num_warps=num_warps - ) - return s diff --git a/linghe/utils/gather.py b/linghe/utils/gather.py index dd29c03..abba7e4 100644 --- a/linghe/utils/gather.py +++ b/linghe/utils/gather.py @@ -55,13 +55,20 @@ def make_row_id_map_kernel(map_ptr, count_ptr, output_ptr, M, B, P, offs += b * E -# """ -# make row id map, shape:[n_tokens, n_experts] -# """ + def triton_make_row_id_map( routing_map: torch.Tensor, multiple_of: int = 1 ): + """ + make row id map, values in the tensor are the row indices + Args: + routing_map: a tensor of 0/1 values, 1 indicates routed + multiple_of: padding the tokens of each expert to multiple of this value + + Returns: + row id map with shape [n_tokens, n_experts] + """ n_tokens, n_experts = routing_map.shape T = 128 block_counts = torch.empty((T, n_experts), dtype=torch.int32, @@ -137,20 +144,21 @@ def make_row_id_map_and_indices_kernel(map_ptr, count_ptr, row_map_ptr, offs += b * E -""" -routing map, shape:[n_tokens, n_experts] -num_out_tokens, shape:[sum(round(bs))] - -row id map, shape:[n_tokens, n_experts] -row id indices, shape: [sum(n_tokens_per_experts)] -""" - - def triton_make_row_id_map_and_indices( 
        routing_map: torch.Tensor,
        num_out_tokens: int,
        multiple_of: int = 1,
 ):
+    """
+    similar to triton_make_row_id_map, but outputs an indices tensor as well
+    Args:
+        routing_map: [n_tokens, n_experts]
+        num_out_tokens: sum(round_up(tokens_per_expert, multiple_of))
+        multiple_of: pad the token count of each expert to a multiple of this value
+    Returns:
+        row_id_map: [n_tokens, n_experts]
+        row_indices: [num_out_tokens]
+    """
     n_tokens, n_experts = routing_map.shape
     T = 128
     block_counts = torch.empty((T, n_experts), dtype=torch.int32,
@@ -208,15 +216,17 @@ def index_select_kernel(x_ptr, out_ptr, scale_ptr, scale_out_ptr, index_ptr, M,
     tl.store(scale_out_ptr + dst_idx, scale, mask=dst_idx < M)
 
 
-"""
-index select for quantized tensor
-x: [bs, dim]
-x_scale: [bs]
-indices: [K]
-"""
-
-
 def triton_index_select(x, indices, scale=None, out=None, scale_out=None):
+    """
+    index select for a quantized tensor
+    Args:
+        x: [bs, dim]
+        indices: [K]
+        scale: [bs]
+    Returns:
+        out: the selected rows of x
+        scale_out: the quantization scales of the selected rows
+    """
     # row-wise read, row-wise write
     M, N = x.shape
     E = indices.shape[0]
@@ -311,22 +321,6 @@ def fill_padded_token_with_zero_kernel(data_ptr, scale_ptr, probs_ptr,
             tl.store(probs_ptr + i, 0.0)
 
 
-"""
-gather with mask map
-inp: [num_tokens, hidden_size], rowwise_data
-scale: [num_tokens, scale_size], rowwise_scale_inv
-prob: [num_tokens], router prob
-row_id_map: [n_experts, num_tokens]
-    index >= 0: row index of output tensor
-    index == -1: ignore
-    Note: index may not be contiguous
-num_out_tokens: output token count, including padding tokens
-contiguous: whether indices in row_id_map is contiguous
-    False means padded
-token_per_expert: [num_experts], token count per expert, non-blocking cuda tensor
-"""
-
-
 def triton_permute_with_mask_map(
         inp: torch.Tensor,
         scale: torch.Tensor,
@@ -336,6 +330,28 @@
         contiguous: bool = True,
         tokens_per_expert: Optional[torch.Tensor] = None
 ):
+    """
+    gather a quantized tensor with a row id map
+    Args:
+        inp: [num_tokens, hidden_size], rowwise quantized tensor
+        scale: [num_tokens], quantization scale
+        probs: router probs, used as weights
+        row_id_map: [num_tokens, n_experts]
+            index >= 0: row index of output tensor
+            index == -1: ignore
+            Note: indices may not be contiguous
+        num_out_tokens: output token count, including padding tokens
+        contiguous: whether indices in row_id_map are contiguous,
+            False means padded
+        tokens_per_expert: [num_experts], token count per expert,
+            non-blocking cuda tensor
+
+    Returns:
+        output: permuted quantized tensor
+        permuted_scale: permuted quantization scale
+        permuted_probs: permuted router probs
+    """
     num_tokens, hidden_size = inp.shape
     num_tokens_, num_experts = row_id_map.shape  # not transposed
     assert num_tokens == num_tokens_
@@ -489,21 +505,6 @@ def batch_smooth_transpose_smooth_permute_kernel(x_ptr, scale_ptr, oss_ptr,
         toffs += H
 
 
-"""
-used for smooth backward in 0.12
-`x`: dy, may be smooth quantized, it should be gather, optional requantized, padded to multiple of 32 and tranposed
-x: [bs, dim]
-scale: [bs], optional
-org_smooth_scale: [dim], optional
-smooth_scales: [n_experts, dim], reversed
-token_count_per_expert: [n_experts], tensor of token count per expert
-splits: [n_experts], list of token_count_per_expert
-indices: [sum(tokens_per_experts)]
-x_q: [sum(roundup(tokens_per_experts)) * dim]
-x_scale: [sum(roundup(tokens_per_experts))]
-"""
-
-
 def triton_batch_transpose_smooth_permute_with_indices(x,
                                                        scale,
                                                        org_smooth_scale,
@@ -511,8 +512,26 @@ def
triton_batch_transpose_smooth_permute_with_indices(x,
                                                        indices,
                                                        token_count_per_expert,
                                                        splits,
-                                                       x_q=None, x_scale=None,
+                                                       x_q=None,
+                                                       x_scale=None,
                                                        round_scale=False):
+    """
+    used for the smooth quantization backward in Megatron 0.12;
+    x is gathered, optionally requantized, padded to a multiple of 32 and transposed
+    Args:
+        x: dy, [bs, dim], smooth quantized
+        scale: [bs], quantization scale
+        org_smooth_scale: [dim]
+        smooth_scales: [n_experts, dim]
+        indices: [sum(tokens_per_experts)]
+        token_count_per_expert: [n_experts], tensor of token count per expert
+        splits: [n_experts], list of token_count_per_expert
+        round_scale: round the quantization scale to a power of 2
+
+    Returns:
+        x_q: [sum(roundup(tokens_per_experts)) * dim]
+        x_scale: [sum(roundup(tokens_per_experts))]
+    """
     # row-wise read, row-wise write
     M, N = x.shape
     n_expert = len(splits)
@@ -596,25 +615,32 @@ def smooth_weighted_permute_with_indices_kernel(grads_ptr,
         tl.store(q_ptr + si * N + i * N + tl.arange(0, N), xq)
 
 
-"""
-select and smooth and quant, used in 0.11 all2all moe
-x: [bs, dim]
-smooth_scales: [n_experts, dim]
-indices: [n_experts*topk]
-x_q: [bs*topk, dim]
-x_scale: [bs*topk]
-"""
-
-
-def triton_smooth_weighted_permute_with_indices(grads, tokens,
+def triton_smooth_weighted_permute_with_indices(grads,
+                                                tokens,
                                                 smooth_scales,
                                                 token_count_per_expert,
-                                                indices, x_q=None,
+                                                indices,
+                                                x_q=None,
                                                 x_scale=None, x_sum=None,
                                                 reverse=False,
                                                 round_scale=False):
-    # row-wise read, row-wise write
+    """
+    select, smooth and quantize; used in Megatron 0.11 all2all MoE
+    Args:
+        grads: [bs, dim]
+        tokens: [bs, dim]
+        smooth_scales: [n_experts, dim]
+        token_count_per_expert: [n_experts]
+        indices: [n_experts*topk]
+        reverse: whether smooth_scales already hold reciprocal values
+        round_scale: whether to round scales to a power of 2
+
+    Returns:
+        x_q: [bs*topk, dim]
+        x_scale: [bs*topk]
+        x_sum: [bs*topk]
+    """
     M, N = grads.shape
     n_expert, n = smooth_scales.shape
     assert N == n, f'{N=} {n=}'
@@ -694,23 +720,31 @@ def smooth_permute_with_indices_kernel(grads_data_ptr,
         tl.store(q_ptr + i * N + tl.arange(0, N), xq)
 
 
-"""
-select and smooth and quant
-grad_data: [bs, dim]
-grad_scale: [bs, dim/128]
-smooth_scales: [n_experts, dim]
-indices: [n_experts*topk]
-x_q: [bs*topk, dim]
-x_scale: [bs*topk]
-"""
-
-
-def triton_smooth_permute_with_indices(grad_data, grad_scale,
+def triton_smooth_permute_with_indices(grad_data,
+                                       grad_scale,
                                        smooth_scales,
                                        token_count_per_expert,
-                                       indices, x_q=None,
-                                       x_scale=None, reverse=False,
+                                       indices,
+                                       x_q=None,
+                                       x_scale=None,
+                                       reverse=False,
                                        round_scale=False):
+    """
+    select, smooth and quantize
+    Args:
+        grad_data: [bs, dim]
+        grad_scale: [bs]
+        smooth_scales: [n_experts, dim]
+        token_count_per_expert: [n_experts]
+        indices: [n_experts*topk]
+        x_q: optional output buffer, [bs*topk, dim]
+        x_scale: optional output buffer, [bs*topk]
+        reverse: whether smooth_scales already hold reciprocal values
+        round_scale: whether to round scales to a power of 2
+
+    Returns:
+        x_q: [bs*topk, dim], quantized output
+        x_scale: [bs*topk], quantization scales
+    """
     # row-wise read, row-wise write
     M, N = grad_data.shape
     n_expert, n = smooth_scales.shape
@@ -796,14 +830,6 @@ def smooth_permute_with_mask_map_kernel(grads_data_ptr, quant_data_ptr,
              mask=mask)
 
 
-# """
-# gather and optional dequant and smooth quant
-# inp: [num_tokens, hidden_size], rowwise_data
-# row_id_map: [n_experts, num_tokens], indices
-# scale: [num_tokens, hs], rowwise_scale_inv, optional
-# num_tokens: [n_experts]
-# smooth_scale_ptrs: [n_experts, hidden_size]
-# """
 def triton_smooth_permute_with_mask_map(
         inp: torch.Tensor,
         row_id_map: torch.Tensor,
@@ -816,6 +842,24 @@ def triton_smooth_permute_with_mask_map(
         reverse=True,
         round_scale=False
 ):
+    """
+    gather, optionally dequantize, then smooth-quantize
+
+    Args:
+
inp: [num_tokens, hidden_size], rowwise quantized tensor + row_id_map: [n_experts, num_tokens], indices + scale: [num_tokens, hs], rowwise_scale_inv, optional + num_tokens: [n_experts] + num_experts: + num_out_tokens: + hidden_size: + smooth_scales: [n_experts, hidden_size] + reverse: + round_scale: + + Returns: + + """ assert row_id_map.shape[1] == num_experts output = torch.empty((num_out_tokens, hidden_size), dtype=torch.float8_e4m3fn, @@ -828,7 +872,6 @@ def triton_smooth_permute_with_mask_map( (num_out_tokens,), dtype=torch.float32, device=inp.device ) - # print(f'{inp.shape=} {row_id_map.shape=} {num_tokens=} {num_out_tokens=}') sm = torch.cuda.get_device_properties(inp.device).multi_processor_count T = triton.cdiv(num_tokens, sm) grid = (num_experts, sm) @@ -847,84 +890,3 @@ def triton_smooth_permute_with_mask_map( round_scale ) return output, permuted_scale - - -@triton.jit -def deprecated_smooth_permute_with_mask_map_kernel(grads_data_ptr, - quant_data_ptr, - mask_map_ptr, - smooth_scale_ptr, - quant_scale_ptr, M, T, - N: tl.constexpr, - REVERSE: tl.constexpr, - ROUND: tl.constexpr): - eid = tl.program_id(axis=0) - bid = tl.program_id(axis=1) - n_experts = tl.num_programs(axis=0) - - # smooth_scale_ptr = tl.load(smooth_scale_ptrs + eid).to(tl.pointer_type(tl.float32)) - smooth_scale = tl.load(smooth_scale_ptr + eid * N + tl.arange(0, N)) - if not REVERSE: - smooth_scale = 1.0 / smooth_scale - for i in range(bid * T, tl.minimum(bid * T + T, M)): - index = tl.load(mask_map_ptr + i * n_experts + eid) - mask = index >= 0 - if index >= 0: - x = tl.load(grads_data_ptr + i * N + tl.arange(0, N), mask=mask).to( - tl.float32) - - x *= smooth_scale - x_max = tl.max(tl.abs(x)) - - scale = tl.maximum(x_max / 448.0, 1e-30) - if ROUND: - scale = tl.exp2(tl.ceil(tl.log2(scale))) - - tl.store(quant_scale_ptr + index, scale, mask=mask) - - x /= scale - xq = x.to(quant_data_ptr.dtype.element_ty) - tl.store(quant_data_ptr + index * N + tl.arange(0, N), xq, - mask=mask) - - -# """ -# gather and smooth quant -# inp: [num_tokens, hidden_size], rowwise_data -# row_id_map: [n_experts, num_tokens], indices -# num_tokens: [n_experts] -# smooth_scale_ptrs: [n_experts, hidden_size] -# """ -def triton_deprecated_smooth_permute_with_mask_map( - inp: torch.Tensor, - row_id_map: torch.Tensor, - num_tokens: int, - num_experts: int, - num_out_tokens: int, - hidden_size: int, - smooth_scales: torch.Tensor, - reverse=True, - round_scale=False -): - assert row_id_map.shape[1] == num_experts - output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, - device=inp.device) - permuted_scale = torch.empty( - (num_out_tokens,), dtype=torch.float32, device=inp.device - ) - sm = torch.cuda.get_device_properties(inp.device).multi_processor_count - T = triton.cdiv(num_tokens, sm) - grid = (num_experts, sm) - deprecated_smooth_permute_with_mask_map_kernel[grid]( - inp, - output, - row_id_map, - smooth_scales, - permuted_scale, - num_tokens, - T, - hidden_size, - reverse, - round_scale, - ) - return output, permuted_scale diff --git a/linghe/utils/loss.py b/linghe/utils/loss.py index b8c735c..1eae2fa 100644 --- a/linghe/utils/loss.py +++ b/linghe/utils/loss.py @@ -44,6 +44,15 @@ def softmax_cross_entropy_forward_kernel(logit_ptr, label_ptr, loss_ptr, TODO: support distributed loss with pytorch ongoing nvshmem feature """ def triton_softmax_cross_entropy_forward(logits, labels): + """ + compute token-wise softmax cross entropy loss + Args: + logits: logits tensor + labels: labels tensor + + Returns: + loss of each 
token + """ M, N = logits.shape device = logits.device loss = torch.empty((M,), device=device, dtype=torch.float32) @@ -93,6 +102,18 @@ def softmax_cross_entropy_backward_kernel(logit_ptr, label_ptr, sum_exp_ptr, def triton_softmax_cross_entropy_backward(logits, labels, sum_exp, max_logit, input_grad, output_grad=None): + """ + backward of softmax cross entropy loss + Args: + logits: logit tensor, [bs, dim] + labels: label tensor, [bs] + sum_exp: [bs] + max_logit: [bs] + input_grad: gradient, [bs, dim] + + Returns: + output_grad: [bs, dim] + """ M, N = logits.shape device = logits.device if output_grad is None: diff --git a/linghe/utils/norm.py b/linghe/utils/norm.py index 085af0a..c55745c 100644 --- a/linghe/utils/norm.py +++ b/linghe/utils/norm.py @@ -2,7 +2,7 @@ import torch import triton import triton.language as tl - +from typing import Optional @triton.jit @@ -27,6 +27,15 @@ def rms_norm_forward_kernel(x_ptr, weight_ptr, out_ptr, eps, M, T, def triton_rms_norm_forward(x, weight, eps=1e-6, out=None): + """ + rms norm + Args: + x: input tensor + weight: weight of rms norm + eps: epsilon of rms norm + Returns: + out: output tensor + """ # row-wise read, row-wise write M, N = x.shape W = 8192 // N @@ -256,10 +265,35 @@ def rms_norm_and_block_quant_forward_t_kernel(x_ptr, -def triton_rms_norm_and_block_quant_forward(x, weight, eps=1e-6, - out=None, scale=None, rms=None, - round_scale=False, - output_mode=2): +def triton_rms_norm_and_block_quant_forward(x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + rms: Optional[torch.Tensor] = None, + round_scale: bool = False, + output_mode: int = 2): + """ + Fused RMSNorm forward and block quantization. + Args: + x: Input tensor, shape [M, N] + weight: RMSNorm weight, shape [N] + eps: epsilon value for L2 normalization. + out: output of quantization data + scale: output of quantization scale. + rms: output of rms + round_scale: Set whether to force power of 2 scales. + output_mode: one of {0, 1, 2}. + 0: only output non-transpose tensor + 1: only output transposed tensor + 2: return both + Returns: + out: quantization data + scale: quantization scale + rms: Reciprocal of the root mean square of the input calculated over the last dimension. 
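+            Equivalently, rms = 1.0 / torch.sqrt(x.float().pow(2).mean(-1) + eps)
+            (reference formula for the returned statistic, assuming x is [M, N]).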
+ transpose_output: quantization data of transposed gradient + transpose_scale: quantization scale of transposed gradient + """ # row-wise read, row-wise write M, N = x.shape assert N <= 8192 and 8192 % N == 0 @@ -369,13 +403,19 @@ def group_norm_gate_forward_kernel(x_ptr, gate_ptr, weight_ptr, out_ptr, eps, bs tl.store(out_ptr + offs, x) -""" -x: [bs, length, n_heads, head_dim], output of attn -gate: [length, bs, dim] -weight: [dim] -output: [length, bs, dim] -""" -def triton_group_norm_gate_forward(x, gate, weight, eps=1e-6, group_size=4): +def triton_group_norm_gate_forward(x: torch.Tensor, gate, weight, eps=1e-6, group_size=4): + """ + norm and gate in linear attention + Args: + x: output of attn, [bs, length, n_heads, head_dim] + gate: gate tensor, [length, bs, dim] + weight: rms norm weight, [dim] + eps: epsilon of rms norm + group_size: group size of group rms norm + + Returns: + output tensor + """ # row-wise read, row-wise write length, bs, dim = gate.shape assert dim <= 8192 and triton.next_power_of_2(dim) == dim and triton.next_power_of_2(group_size) == group_size @@ -485,4 +525,105 @@ def triton_group_norm_gate_backward(grad_output, x, gate, weight, eps=1e-6, grou num_warps=8 ) dw = tmp_dw.sum(dim=0).to(weight.dtype) - return dx, dg, dw \ No newline at end of file + return dx, dg, dw + + + +@triton.jit +def rms_norm_and_smooth_quant_forward_kernel(x_ptr, weight_ptr, smooth_scale_ptr, + out_ptr, scale_ptr, max_ptr, rms_ptr, + eps, + M, + T, + N: tl.constexpr, + W: tl.constexpr, + CALIBRATE: tl.constexpr, + OUTPUT: tl.constexpr, + ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + # row-wise read, row-wise write + weight = tl.load(weight_ptr + tl.arange(0, N)).to(tl.float32)[None, :] + smooth_scale = tl.load(smooth_scale_ptr + tl.arange(0, N))[None, :] + smooth_scale = 1.0 / tl.maximum(smooth_scale, 1e-30) + if CALIBRATE: + # triton 3.3.1 has bug with N = 2048 and calibrate=True + maxs = tl.zeros((N, ), dtype=tl.float32) + offs = pid * W * T * N + tl.arange(0, W)[:, None] * N + tl.arange(0, N)[ + None, :] + for i in range(T): + indices = pid * W * T + i * W + tl.arange(0, W) + x = tl.load(x_ptr + offs, mask=indices[:, None] < M).to(tl.float32) + rms = 1/tl.sqrt(tl.sum(x * x, axis=1) / N + eps) + if OUTPUT: + tl.store(rms_ptr + indices, rms, mask=indices < M) + x = x * rms[:, None] * weight + + if CALIBRATE: + maxs = tl.maximum(maxs, tl.max(tl.abs(x),0)) + + x = x * smooth_scale + scale = tl.maximum(tl.max(tl.abs(x), 1) / 448.0, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + q = (x / scale[:, None]).to(out_ptr.dtype.element_ty) + tl.store(scale_ptr + indices, scale, mask=indices < M) + tl.store(out_ptr + offs, q, mask=indices[:, None] < M) + offs += N * W + + if CALIBRATE: + tl.store(max_ptr + pid * N + tl.arange(0, N), maxs) + + +# rms is used for moe routing, it is stored as 1/rms +def triton_rms_norm_and_smooth_quant_forward(x, weight, smooth_scale=None, + eps=1e-6, + out=None, scale=None, rms=None, + calibrate=False, + output_rms=False, + round_scale=False): + """ + + """ + M, N = x.shape + assert N <= 8192 and 8192 % N == 0 + device = x.device + + if out is None: + out = torch.empty((M, N), device=device, dtype=torch.float8_e4m3fn) + + if scale is None: + scale = torch.empty((M,), device=device, dtype=torch.float32) + W = 8192 // N + T = 8 if M // W >= 4096 else 4 + assert M % (T * W) == 0 + g = M // (T * W) + if calibrate: + maxs = torch.empty((g, N), dtype=torch.float32, device=device) + else: + maxs = None + if output_rms and rms is None: + rms = 
torch.empty((M,), dtype=torch.float32, device=device) + grid = (g,) + rms_norm_and_smooth_quant_forward_kernel[grid]( + x, + weight, + smooth_scale, + out, + scale, + maxs, + rms, + eps, + M, + T, + N, + W, + calibrate, + output_rms, + round_scale, + num_stages=3, + num_warps=2 if N == 2048 else 4 + ) + if calibrate: + maxs = maxs.amax(0) + + return out, scale, maxs, rms diff --git a/linghe/utils/rearange.py b/linghe/utils/rearange.py index c1ed887..58e5d6c 100644 --- a/linghe/utils/rearange.py +++ b/linghe/utils/rearange.py @@ -32,15 +32,20 @@ def split_and_cat_kernel(x_ptr, y_ptr, scale_ptr, scale_output_ptr, count_ptr, mask=i * K + tl.arange(0, K) < count) -""" -select and smooth and quant -x: [bs, dim] -counts: [n_split] -indices: [n_split] -""" - - def triton_split_and_cat(x, counts, indices, scales=None): + """ + split x to multiple tensors and cat with indices, + it is used for permutation in moe + Args: + x: [bs, dim] + counts: [n_split] + indices: [n_split] + scales: [bs] + + Returns: + y: output tensor + output_scales: output scales if scales is not None + """ M, N = x.shape n_split = counts.shape[0] device = x.device diff --git a/linghe/utils/reduce.py b/linghe/utils/reduce.py index 646d8b0..72b3b3d 100644 --- a/linghe/utils/reduce.py +++ b/linghe/utils/reduce.py @@ -46,8 +46,19 @@ def abs_max_kernel(x_ptr, tl.store(output_ptr + pid * W + tl.arange(0, W), scale) -# update weight smooth scale for next step with x input def triton_abs_max(x, scale=None, smooth_scale=None, min_value=1e-30, axis=0): + """ + columnwise abs max of x, it is used in smooth quantization + Args: + x: input tensor, may be quantized tensor + scale: quantization scale if x is quantized + smooth_scale: optional smooth scale + min_value: output = max(max(abs(x,0)), min_value) + axis: reduce axis + + Returns: + max tensor + """ assert axis == 0 N = x.size(-1) M = x.numel() // N @@ -95,6 +106,14 @@ def batch_count_zero_kernel(input_ptrs, size_ptr, count_ptr, B: tl.constexpr): def triton_batch_count_zero(xs): + """ + count zero in tensor list, it is used to monitor zeros in gradient tensor + Args: + xs: input tensors + + Returns: + a single-value int64 tensor + """ device = xs[0].device sizes = torch.tensor([x.numel() for x in xs], dtype=torch.int64, device=device) @@ -142,6 +161,15 @@ def batch_sum_with_ord_kernel(input_ptrs, size_ptr, count_ptr, B: tl.constexpr, def triton_batch_sum_with_ord(xs, ord=2): + """ + return sum(abs(x)**ord). + Args: + xs: Tensor lists. + ord: the order of tensor. 
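+
+    Example (reference semantics only; illustrative, not the kernel launch):
+        >>> xs = [torch.randn(1024, device='cuda') for _ in range(4)]
+        >>> total = sum((x.float().abs() ** 2).sum() for x in xs)
+        >>> # triton_batch_sum_with_ord(xs, ord=2) returns total;
+        >>> # total ** 0.5 is then the global gradient norm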
+ + Returns: + a single-value fp32 tensor + """ assert ord in (1, 2) device = xs[0].device sizes = torch.tensor([x.numel() for x in xs], dtype=torch.int64, diff --git a/linghe/utils/rope.py b/linghe/utils/rope.py index beb881f..85f5696 100644 --- a/linghe/utils/rope.py +++ b/linghe/utils/rope.py @@ -80,15 +80,18 @@ def half_rope_forward_kernel(q_ptr, k_ptr, freqs_ptr, qo_ptr, ko_ptr, B, 0, h)[:, None] + tl.arange(0, D)[None, :], k) -""" -apply norm to qk, then apply rope to qk, then transpose qkv -q: [len, bs, q_head, head_dim] -k: [len, bs, kv_head, head_dim] -v: [len, bs, kv_head, head_dim] -""" - - def triton_half_rope_forward(q, k, freqs): + """ + apply norm to qk, then apply half rope to qk + Args: + q: query tensor, [len, bs, q_head, head_dim] + k: key tensor, [len, bs, kv_head, head_dim] + freqs: rope freqs + + Returns: + qo: + ko: + """ L, B, H, D = q.shape h = k.shape[2] assert freqs.shape[1] == D // 2 @@ -170,13 +173,6 @@ def half_rope_backward_kernel(q_ptr, k_ptr, freqs_ptr, 0, D)[None, :], k) -""" -apply norm to qk, then apply rope to qk, then transpose qkv -q: [len, bs, q_head, head_dim] -k: [len, bs, kv_head, head_dim] -v: [len, bs, kv_head, head_dim] -""" - def triton_half_rope_backward(q_grad, k_grad, freqs, inplace=False): assert inplace @@ -325,16 +321,30 @@ def qk_norm_and_half_rope_forward_kernel(qkv_ptr, 0, D)[None, :], v1) -""" -use qkv as input, to reduce redundant gradient copy in backward -split qkv, apply norm to qk, apply rope to qk -qkv: [len, bs, kv_head*(q_head//kv_head + 2 ) * head_dim)] -""" - - def triton_qk_norm_and_half_rope_forward(qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4, eps=1e-6, interleave=True, transpose=False): + + """ + split qkv to q/k/v, apply qk norm and half rope to q/k, + transpose q/k/v to flash-attention layout + Args: + qkv: QKV tensor with size of [S, B, dim], heads are interleaved + q_norm_weight: rms norm weight for query + k_norm_weight: rms norm weight for key + freqs: Freqs tensor based on half dim. + H: Number of attention heads. + h: Number of key/value heads. + eps: epsilon value for L2 normalization. 
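+            (i.e., RMSNorm x / sqrt(mean(x**2, -1) + eps) * weight, applied
+            head-wise to q/k before the rotary embedding; illustrative formula)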
+        interleave: whether heads of qkv are interleaved, i.e., [qqkvqqkv]
+        transpose: whether qkv is transposed, i.e., [S, B, dim];
+            only the transposed format is supported currently
+    Returns:
+        qo: shape [B, S, H, head_dim]
+        ko: shape [B, S, h, head_dim]
+        vo: shape [B, S, h, head_dim]
+    """
+    assert transpose
     L, B, Dim = qkv.shape
     stride = qkv.stride(1)  # qkv may be a slice of a tensor
@@ -532,17 +542,28 @@ def qk_norm_and_half_rope_backward_kernel(gq_ptr, gk_ptr, gv_ptr,
                  0, D)[None, :], v1)
 
 
-"""
-apply norm to qk, then apply rope to qk
-q: [len, bs, q_head, head_dim]
-k: [len, bs, kv_head, head_dim]
-v: [len, bs, kv_head, head_dim]
-"""
-
-
 def triton_qk_norm_and_half_rope_backward(gq, gk, gv, qkv, q_norm_weight,
                                           k_norm_weight, freqs, eps=1e-6,
                                           transpose=False, interleave=True):
+    """
+    backward kernel of triton_qk_norm_and_half_rope_forward
+    Args:
+        gq: gradient of qo, [bs, len, q_head, head_dim]
+        gk: gradient of ko, [bs, len, kv_head, head_dim]
+        gv: gradient of vo, [bs, len, kv_head, head_dim]
+        qkv: input qkv
+        q_norm_weight: rms norm weight for query
+        k_norm_weight: rms norm weight for key
+        freqs: rope freqs based on half dim
+        eps: epsilon of rms norm
+        transpose: whether qkv is transposed, must be True currently
+        interleave: whether heads of qkv are interleaved
+
+    Returns:
+        dqkv: gradient of qkv
+        dqw: gradient of q_norm_weight
+        dkw: gradient of k_norm_weight
+    """
     assert transpose
     B, L, H, D = gq.shape
     stride = qkv.stride(1)
diff --git a/linghe/utils/scatter.py b/linghe/utils/scatter.py
index 7549015..bb39945 100644
--- a/linghe/utils/scatter.py
+++ b/linghe/utils/scatter.py
@@ -3,12 +3,12 @@
 Copyright (c) Ant Financial Service Group and its affiliates.
 """
 
+from typing import Optional
 import torch
 import triton
 import triton.language as tl
 
 
-# for megatron 0.11 scatter_add
 @triton.jit
 def aligned_scatter_add_kernel(x_ptr, o_ptr, indices_ptr, weights_ptr, M,
@@ -30,7 +30,21 @@ def aligned_scatter_add_kernel(x_ptr, o_ptr, indices_ptr, weights_ptr, M,
     tl.store(o_ptr + pid * N + offs, sums)
 
 
-def triton_aligned_scatter_add(x, outputs, indices, weights=None):
+def triton_aligned_scatter_add(x: torch.Tensor,
+                               outputs: torch.Tensor,
+                               indices: torch.Tensor,
+                               weights: Optional[torch.Tensor] = None):
+    """
+    scatter_add for Megatron 0.11
+    Args:
+        x: input tensor
+        outputs: output tensor
+        indices: gather indices
+        weights: rowwise weights, i.e., the router probs in an MoE router
+
+    Returns:
+        output tensor
+    """
     M, N = x.shape
     m = outputs.size(0)
 
@@ -82,6 +96,16 @@ def fp32_to_bf16_kernel(x_ptr, o_ptr, M, T, N: tl.constexpr):
 
 
 def triton_scatter_add(x, outputs, indices):
+    """
+    naive reference version of scatter_add, very slow
+    Args:
+        x: input tensor
+        outputs: output tensor
+        indices: indices
+
+    Returns:
+        outputs
+    """
     M, N = x.shape
 
     float_outputs = torch.zeros(outputs.shape, dtype=torch.float32,
@@ -149,18 +173,22 @@ def unpermute_with_mask_map_kernel(grads_ptr, probs_ptr, mask_map_ptr,
     tl.store(output_ptr + pid * N + tl.arange(0, N), sums)
 
 
-# """
-# gather and smooth quant
-# inp: [num_tokens, hidden_size], rowwise_data
-# row_id_map: [n_experts, num_tokens], indices
-# prob: [num_out_tokens], rowwise_scale_inv
-# """
-
 def triton_unpermute_with_mask_map(
        grad: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
 ):
+    """
+    scatter add with a row id map
+    Args:
+        grad: gradient tensor, [num_out_tokens, hidden_size]
+        row_id_map: row id map, [num_tokens, n_experts]
+        probs: [num_out_tokens]
+
+    Returns:
+        output: [num_tokens, hidden_size]
+        restore_probs: [num_tokens, num_experts]
+    """
     hidden_size = grad.shape[1]
     num_tokens, num_experts = row_id_map.shape  # not transposed
diff --git a/linghe/utils/silu.py b/linghe/utils/silu.py
index c0d8aa1..bfcf910 100644
--- 
a/linghe/utils/silu.py +++ b/linghe/utils/silu.py @@ -3,11 +3,159 @@ Copyright (c) Ant Financial Service Group and its affiliates. """ +from typing import Optional import torch import triton import triton.language as tl + +@triton.jit +def weighted_silu_forward_kernel(x_ptr, weight_ptr, out_ptr, M, T, + N: tl.constexpr, + n: tl.constexpr, + W: tl.constexpr, + WEIGHT: tl.constexpr): + pid = tl.program_id(axis=0) + + row_offs = pid * W * T * n + tl.arange(0, W)[:, None] * n + col_offs = tl.arange(0, n)[None, :] + + for i in range(T): + indices = pid * W * T + i * W + tl.arange(0, W) + mask = indices[:, None] < M + x1 = tl.load(x_ptr + row_offs * 2 + col_offs, mask=mask).to(tl.float32) + x2 = tl.load(x_ptr + n + row_offs * 2 + col_offs, mask=mask).to( + tl.float32) + if WEIGHT: + w = tl.load(weight_ptr + indices, mask=indices < M).to(tl.float32)[:, + None] + x = x1 / (1 + tl.exp(-x1)) * x2 * w + else: + x = x1 / (1 + tl.exp(-x1)) * x2 + tl.store(out_ptr + row_offs + col_offs, x, mask=mask) + row_offs += n * W + + +# used in bf16 moe +def triton_weighted_silu_forward(x, weight=None, out=None): + """ + compute silu(x)*weight, used in bf16/fp16 training with MoE + Args: + x: input tensor + weight: tokenwise weight + Returns: + out: output tensor + """ + # row-wise read, row-wise write + M, N = x.shape + assert N <= 8192 + device = x.device + if out is None: + out = torch.empty((M, N // 2), device=device, dtype=x.dtype) + WEIGHT = weight is not None + W = 8192 // N + T = 8 + grid = (triton.cdiv(M, T * W),) + weighted_silu_forward_kernel[grid]( + x, + weight, + out, + M, T, + N, + N // 2, + W, + WEIGHT, + num_stages=3, + num_warps=8 + ) + return out + + +@triton.jit +def weighted_silu_backward_kernel(g_ptr, x_ptr, weight_ptr, dx_ptr, dw_ptr, M, + T, + N: tl.constexpr, + n: tl.constexpr, + W: tl.constexpr, + WEIGHT: tl.constexpr): + pid = tl.program_id(axis=0) + + offs = pid * W * T * N + tl.arange(0, W)[:, None] * N + tl.arange(0, n)[ + None, :] + hoffs = pid * W * T * n + tl.arange(0, W)[:, None] * n + tl.arange(0, n)[ + None, :] + for i in range(T): + mask = pid * W * T + i * W + tl.arange(0, W) + x1 = tl.load(x_ptr + offs, mask=mask[:, None] < M).to(tl.float32) + x2 = tl.load(x_ptr + offs + n, mask=mask[:, None] < M).to(tl.float32) + g = tl.load(g_ptr + hoffs, mask=mask[:, None] < M).to(tl.float32) + if WEIGHT: + w = tl.load(weight_ptr + mask, mask=mask < M).to(tl.float32)[:, None] + sigmoid = 1 / (1 + tl.exp(-x1)) + dw = tl.sum(x1 * sigmoid * x2 * g, 1) + tl.store(dw_ptr + mask, dw, mask=mask < M) + dx1 = g * x2 * w * sigmoid * (1 + x1 * tl.exp(-x1) * sigmoid) + tl.store(dx_ptr + offs, dx1, mask=mask[:, None] < M) + + dx2 = g * x1 * sigmoid * w + tl.store(dx_ptr + offs + n, dx2, mask=mask[:, None] < M) + else: + sigmoid = 1 / (1 + tl.exp(-x1)) + dx1 = g * x2 * sigmoid * (1 + x1 * tl.exp(-x1) * sigmoid) + tl.store(dx_ptr + offs, dx1, mask=mask[:, None] < M) + + dx2 = g * x1 * sigmoid + tl.store(dx_ptr + offs + n, dx2, mask=mask[:, None] < M) + offs += N * W + hoffs += n * W + + +def triton_weighted_silu_backward(g: torch.Tensor, + x: torch.Tensor, + weight: Optional[torch.Tensor] = None): + """ + backward of triton_weighted_silu_forward + Args: + g: gradient tensor + x: input tensor + weight: weight tensor + + Returns: + dx: gradient of x + dw: gradient of weight + """ + # row-wise read, row-wise write + M, N = x.shape + assert N <= 8192 + device = x.device + if weight is not None: + dw = torch.empty(weight.shape, device=device, dtype=x.dtype) + WEIGHT = True + else: + dw = None + WEIGHT = 
False + dx = torch.empty((M, N), device=device, dtype=x.dtype) + W = 8192 // N + T = 8 + grid = (triton.cdiv(M, W*T),) + weighted_silu_backward_kernel[grid]( + g, + x, + weight, + dx, + dw, + M, T, + N, + N // 2, + W, + WEIGHT, + num_stages=3, + num_warps=8 + ) + return dx, dw + + @triton.jit def silu_and_block_quant_forward_kernel(x_ptr, out_ptr, scale_ptr, @@ -64,11 +212,27 @@ def silu_and_block_quant_forward_kernel(x_ptr, tl.trans(xq), mask=indices[None, :] < M) -# used in shared expert -def triton_silu_and_block_quant_forward(x, out=None, scale=None, +def triton_silu_and_block_quant_forward(x, + out=None, + scale=None, round_scale=False, output_mode=2): - # row-wise read, row-wise write + """ + fused silu and blockwise quantization, used in shared expert + Args: + x: input tensor + round_scale: whether round scale to power of 2 + output_mode: one of {0, 1, 2} + 0: only output non-transposed quantized tensor + 1: only output transposed quantized tensor + 2: output both + + Returns: + out: quantized tensor + scale: quantization scale + transpose_output: quantized tensor of transposed output + transpose_scale: quantization scale of transposed output + """ M, N = x.shape n = N // 2 device = x.device @@ -177,7 +341,19 @@ def silu_and_block_quant_backward_kernel(g_ptr, x_ptr, # used in shared expert def triton_silu_and_block_quant_backward(g, x, round_scale=False): - # row-wise read, row-wise write + """ + backward of triton_silu_and_block_quant_forward + Args: + g: gradient + x: input tensor + round_scale: whether round to power of 2 + + Returns: + dx: quantized non-transposed gradient + dx_scale: scales of quantization non-transposed gradient + transpose_dx: quantized transposed gradient + transpose_dx_scale: scales of quantization transposed gradient + """ M, N = x.shape n = N // 2 device = x.device @@ -281,7 +457,7 @@ def batch_weighted_silu_and_block_quant_forward_kernel(x_ptr, weight_ptr, mask=indices[None, :] < count) -# used in routed experts + def triton_batch_weighted_silu_and_block_quant_forward(x, weight, counts, @@ -290,7 +466,25 @@ def triton_batch_weighted_silu_and_block_quant_forward(x, scale=None, round_scale=False, output_mode=2): - # row-wise read, row-wise write + """ + silu and blockwise quantize activation in routed experts + Args: + x: activation tensor in routed experts + weight: router prob tensor + counts: cuda tensor of token count per expert + splits: python int list of token count per expert + round_scale: whether round scale to power of 2 + output_mode: one of {0, 1, 2} + 0: only output non-transposed quantized tensor + 1: only output transposed quantized tensor + 2: output both + + Returns: + out: quantized tensor + scale: quantization scale + transpose_output: quantized tensor of transposed output + transpose_scale: quantization scale of transposed output + """ M, N = x.shape n = N // 2 n_experts = counts.shape[0] @@ -307,7 +501,8 @@ def triton_batch_weighted_silu_and_block_quant_forward(x, dtype=torch.float32) # intra layout and inner layput are not consist, # tensors will be viewed after splitting - scale = torch.empty((M * n // 128,), device=device, dtype=torch.float32) + if scale is None: + scale = torch.empty((M * n // 128,), device=device, dtype=torch.float32) if M == 0: return out, scale, transpose_output, transpose_scale @@ -437,6 +632,22 @@ def triton_batch_weighted_silu_and_block_quant_backward(g, x, weight, counts, splits=None, round_scale=False): + """ + backward of triton_batch_weighted_silu_and_block_quant_forward + Args: + g: gradient + x: input 
tensor + weight: router prob tensor + counts: cuda tensor of token count per expert + splits: python int list of token count per expert + round_scale: whether round scale to power of 2 + Returns: + dx: quantized non-transposed gradient + dx_scale: scales of quantization non-transposed gradient + dw: gradient of weight + transpose_dx: quantized transposed gradient + transpose_dx_scale: scales of quantization transposed gradient + """ # row-wise read, row-wise write M, N = x.shape n = N // 2 @@ -485,3 +696,744 @@ def triton_batch_weighted_silu_and_block_quant_backward(g, x, weight, ) dw = dws.sum(1, keepdim=True).to(weight.dtype) return dx, dx_scale, dw, transpose_dx, transpose_dx_scale + + + + + +# n is power of 2 +@triton.jit +def silu_and_smooth_quant_forward_kernel(x_ptr, smooth_scale_ptr, out_ptr, scale_ptr, + max_ptr, M, T, n: tl.constexpr, + W: tl.constexpr, ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + + row_offs = pid * T * W * n + tl.arange(0, W)[:, None] * n + col_offs = tl.arange(0, n)[None, :] + smooth_scale = tl.load(smooth_scale_ptr + tl.arange(0, n)) + smooth_scale = 1.0 / smooth_scale + if CALIBRATE: + maxs = tl.zeros((W, n), dtype=tl.float32) + + for i in range(T): + indices = pid * T * W + i * W + tl.arange(0, W) + mask = indices[:, None] < M + x1 = tl.load(x_ptr + row_offs * 2 + col_offs, mask=mask).to(tl.float32) + x2 = tl.load(x_ptr + n + row_offs * 2 + col_offs, mask=mask).to( + tl.float32) + x = x1 / (1 + tl.exp(-x1)) * x2 + if CALIBRATE: + maxs = tl.maximum(x.abs(), maxs) + x = x * smooth_scale + scale = tl.maximum(tl.max(x.abs(), 1) / 448, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(scale_ptr + indices, scale, mask=indices < M) + x = (x / scale[:, None]).to(out_ptr.dtype.element_ty) + tl.store(out_ptr + row_offs + col_offs, x, mask=mask) + row_offs += n * W + + if CALIBRATE: + maxs = tl.max(maxs, 0) + tl.store(max_ptr + pid * n + tl.arange(0, n), maxs) + + +# n is NOT power of 2 +@triton.jit +def compatible_silu_and_smooth_quant_forward_kernel(x_ptr, smooth_scale_ptr, out_ptr, + scale_ptr, max_ptr, M, + T: tl.constexpr, n: tl.constexpr, + B: tl.constexpr, + ROUND: tl.constexpr, + CALIBRATE: tl.constexpr): + pid = tl.program_id(axis=0) + + # rowwise read with block size [T, B] + row_offs = pid * T * n + tl.arange(0, T)[:, None] * n + col_offs = tl.arange(0, B)[None, :] + + nb = n // B + maxs = tl.zeros((T,), dtype=tl.float32) + for i in range(nb): + + smooth_scale = tl.load(smooth_scale_ptr + i * B + tl.arange(0, B)) + x1 = tl.load(x_ptr + row_offs * 2 + col_offs).to(tl.float32) + x2 = tl.load(x_ptr + n + row_offs * 2 + col_offs).to(tl.float32) + x = x1 / (1 + tl.exp(-x1)) * x2 + if CALIBRATE: + x_maxs = tl.max(x.abs(), 0) + tl.store(max_ptr + pid * n + i * B + tl.arange(0, B), x_maxs) + x = x / smooth_scale + maxs = tl.maximum(tl.max(x.abs(), 1), maxs) + col_offs += B + + scale = tl.maximum(maxs / 448, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(scale_ptr + pid * T + tl.arange(0, T), scale) + + col_offs = tl.arange(0, B)[None, :] + for i in range(nb): + smooth_scale = tl.load(smooth_scale_ptr + i * B + tl.arange(0, B)) + + x1 = tl.load(x_ptr + row_offs * 2 + col_offs).to(tl.float32) + x2 = tl.load(x_ptr + n + row_offs * 2 + col_offs).to(tl.float32) + x = x1 / (1 + tl.exp(-x1)) * x2 + x = x / smooth_scale + + x = (x / scale[:, None]).to(out_ptr.dtype.element_ty) + tl.store(out_ptr + row_offs + col_offs, x) + col_offs += B + + + + +# used in shared expert +def 
triton_silu_and_smooth_quant_forward(x, smooth_scale=None, out=None, scale=None, + maxs=None, round_scale=False, + calibrate=False): + """ + + """ + M, N = x.shape + n = N // 2 + device = x.device + if out is None: + out = torch.empty((M, N // 2), device=device, dtype=torch.float8_e4m3fn) + if scale is None: + scale = torch.empty((M,), device=device, dtype=torch.float32) + + if triton.next_power_of_2(N) == N and N <= 8192: + # sm = torch.cuda.get_device_properties(device).multi_processor_count + W = 8192 // N + T = 8 if M//W >= 1024 else 4 + assert M % (T*W) == 0 + g = M//(T*W) + # T = triton.cdiv(M, sm * W) + if maxs is None and calibrate: + maxs = torch.empty((g, n), device=device, dtype=torch.float32) + grid = (g,) + silu_and_smooth_quant_forward_kernel[grid]( + x, + smooth_scale, + out, + scale, + maxs, + M, + T, + n, + W, + round_scale, + calibrate, + num_stages=2, + num_warps=16 + ) + else: + B = 512 + T = 16 + assert n % B == 0 and M % T == 0 + grid = (M // T,) + if maxs is None and calibrate: + maxs = torch.empty((M // T, n), device=device, dtype=torch.float32) + compatible_silu_and_smooth_quant_forward_kernel[grid]( + x, + smooth_scale, + out, + scale, + maxs, + M, + T, + N // 2, + B, + round_scale, + calibrate, + num_stages=2, + num_warps=16 + ) + + if calibrate: + maxs = maxs.amax(0) + + + return out, scale, maxs + + + + + +@triton.jit +def silu_and_smooth_quant_backward_kernel(g_ptr, x_ptr, + smooth_scale_ptr, + transpose_smooth_scale_ptr, + dx_ptr, dx_scale_ptr, + transpose_dx_ptr, + transpose_dx_scale_ptr, + M, + n: tl.constexpr, + T: tl.constexpr, + B: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr): + pid = tl.program_id(axis=0) + + offs = pid * T * n * 2 + tl.arange(0, T)[:, None] * n * 2 + tl.arange(0, B)[ + None, :] + hoffs = pid * T * n + tl.arange(0, T)[:, None] * n + tl.arange(0, B)[None, + :] + toffs = pid * T + tl.arange(0, B)[:, None] * M + tl.arange(0, T)[None, :] + nb = n // B + maxs = tl.zeros((T, ), dtype=tl.float32) + transpose_smooth_scale = tl.load(transpose_smooth_scale_ptr + pid * T + tl.arange(0, T))[:, None] + for i in range(nb): + smooth_scale_1 = tl.load(smooth_scale_ptr + i * B + tl.arange(0, B)) + smooth_scale_2 = tl.load(smooth_scale_ptr + n + i * B + tl.arange(0, B)) + if not REVERSE: + smooth_scale_1 = 1 / smooth_scale_1 + smooth_scale_2 = 1 / smooth_scale_2 + + x1 = tl.load(x_ptr + offs).to(tl.float32) + x2 = tl.load(x_ptr + offs + n).to(tl.float32) + g = tl.load(g_ptr + hoffs).to(tl.float32) + sigmoid = 1 / (1 + tl.exp(-x1)) + + # x1 = tl.load(x_ptr + offs) + # x2 = tl.load(x_ptr + offs + n) + # g = tl.load(g_ptr + hoffs) + # sigmoid = 1 / (1 + tl.exp(-x1.to(tl.float32))) + + dx1 = g * x2 * sigmoid * ( + 1 + x1 * (1 - sigmoid)) + dx2 = g * x1 * sigmoid + + t_dx = dx1 * transpose_smooth_scale + t_s = tl.maximum(tl.max(tl.abs(t_dx), 0) / 448, 1e-30) + if ROUND: + t_s = tl.exp2(tl.ceil(tl.log2(t_s))) + t_dx = t_dx/t_s + tl.store(transpose_dx_ptr + toffs, tl.trans(t_dx.to(transpose_dx_ptr.dtype.element_ty))) + tl.store(transpose_dx_scale_ptr + pid * n * 2 + i * B + tl.arange(0, B), t_s) + + t_dx = dx2 * transpose_smooth_scale + t_s = tl.maximum(tl.max(tl.abs(t_dx), 0) / 448, 1e-30) + if ROUND: + t_s = tl.exp2(tl.ceil(tl.log2(t_s))) + t_dx = t_dx/t_s + tl.store(transpose_dx_ptr + M * n + toffs, tl.trans(t_dx.to(transpose_dx_ptr.dtype.element_ty))) + tl.store(transpose_dx_scale_ptr + pid * n * 2 + n + i * B + tl.arange(0, B), t_s) + + dx1 = dx1 * smooth_scale_1 + dx2 = dx2 * smooth_scale_2 + + # maxs = tl.maximum( + # 
tl.maximum(dx1.abs(), dx2.abs()), maxs) + maxs = tl.maximum( + tl.maximum(tl.max(dx1.abs(), 1), tl.max(dx2.abs(), 1)), maxs) + + offs += B + hoffs += B + toffs += B * M + + scale = tl.maximum(maxs / 448, 1e-30) + # scale = tl.maximum(tl.max(maxs, 1) / 448, 1e-30) + + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(dx_scale_ptr + pid * T + tl.arange(0, T), scale) + + s = 1 / scale[:, None] + offs = pid * T * n * 2 + tl.arange(0, T)[:, None] * n * 2 + tl.arange(0, B)[ + None, :] + hoffs = pid * T * n + tl.arange(0, T)[:, None] * n + tl.arange(0, B)[None, + :] + for i in range(nb): + smooth_scale_1 = tl.load(smooth_scale_ptr + i * B + tl.arange(0, B)) + smooth_scale_2 = tl.load(smooth_scale_ptr + n + i * B + tl.arange(0, B)) + if not REVERSE: + smooth_scale_1 = 1 / smooth_scale_1 + smooth_scale_2 = 1 / smooth_scale_2 + + x1 = tl.load(x_ptr + offs).to(tl.float32) + x2 = tl.load(x_ptr + offs + n).to(tl.float32) + g = tl.load(g_ptr + hoffs).to(tl.float32) + sigmoid = 1 / (1 + tl.exp(-x1)) + dx1 = g * x2 * sigmoid * ( + 1 + x1 * (1 - sigmoid)) * smooth_scale_1 + dx2 = g * x1 * sigmoid * smooth_scale_2 + + dx1 = (dx1 * s).to(dx_ptr.dtype.element_ty) + dx2 = (dx2 * s).to(dx_ptr.dtype.element_ty) + + tl.store(dx_ptr + offs, dx1) + tl.store(dx_ptr + n + offs, dx2) + offs += B + hoffs += B + +# requant multi-column quantized tensor +@triton.jit +def _requant_kernel(x_ptr, scale_ptr, scales_ptr, + M, + N, + H: tl.constexpr, + W: tl.constexpr + ): + rid = tl.program_id(axis=0) + cid = tl.program_id(axis=1) + offs = rid * H * N + cid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, W)[None, :] + global_scale = tl.load(scale_ptr + rid * H + tl.arange(0, H)) + # scales is stored with column-major format + local_scale = tl.load(scales_ptr + cid * M + rid * H + tl.arange(0, H)) + x = tl.load(x_ptr+offs).to(tl.float32) + rescale = local_scale/global_scale + x = x * rescale[:,None] + tl.store(x_ptr+offs, x) + + +# used in shared expert +def triton_silu_and_smooth_quant_backward(g, x, + smooth_scale=None, + transpose_smooth_scale=None, + reverse=True, + round_scale=False): + """ + + """ + assert round_scale + M, N = x.shape + n = N // 2 + device = x.device + dx = torch.empty((M, N), device=device, dtype=torch.float8_e4m3fn) + dx_scale = torch.empty((M,), device=device, dtype=torch.float32) + scale_shape = (N, ) + transpose_dx = torch.empty((N, M), device=device, dtype=torch.float8_e4m3fn) + transpose_dx_scale = torch.empty(scale_shape, device=device, dtype=torch.float32) + + T = 32 + B = 32 + assert M % T == 0 and n % B == 0 + transpose_dx_scales = torch.empty((M // T, N), device=device, dtype=torch.float32) + grid = (M // T,) + silu_and_smooth_quant_backward_kernel[grid]( + g, + x, + smooth_scale, + transpose_smooth_scale, + dx, + dx_scale, + transpose_dx, + transpose_dx_scales, + M, + n, + T, + B, + reverse, + round_scale, + num_stages=3, + num_warps=2 + ) + transpose_dx_scale = transpose_dx_scales.amax(0) + grid = (N // B, M // T) + _requant_kernel[grid](transpose_dx, transpose_dx_scale, transpose_dx_scales, + N, + M, + B, + T) + + return dx, dx_scale, transpose_dx, transpose_dx_scale + + +@triton.jit +def batch_weighted_silu_and_smooth_quant_forward_kernel(x_ptr, weight_ptr, + smooth_scale_ptr, + out_ptr, + scale_ptr, max_ptr, + count_ptr, + accum_ptr, M, + n: tl.constexpr, + W: tl.constexpr, + ROUND: tl.constexpr, + REVERSE: tl.constexpr, + CALIBRATE: tl.constexpr): + eid = tl.program_id(axis=0) + tid = tl.program_id(axis=1) + sm = tl.num_programs(axis=1) + + count = tl.load(count_ptr + 
eid) + ei = tl.load(accum_ptr + eid) + si = ei - count + c = tl.cdiv(count, sm * W) + + row_offs = si * n + tid * c * W * n + tl.arange(0, W)[:, None] * n + col_offs = tl.arange(0, n)[None, :] + smooth_scale = tl.load(smooth_scale_ptr + n * eid + tl.arange(0, n)) + if not REVERSE: + smooth_scale = 1.0 / smooth_scale + + if CALIBRATE: + maxs = tl.zeros((W, n), dtype=tl.float32) + + for i in range(c): + indices = tid * c * W + i * W + tl.arange(0, W) + mask = indices[:, None] < count + x1 = tl.load(x_ptr + row_offs * 2 + col_offs, mask=mask).to(tl.float32) + x2 = tl.load(x_ptr + n + row_offs * 2 + col_offs, mask=mask).to( + tl.float32) + + w = tl.load(weight_ptr + si + indices, mask=indices < count).to( + tl.float32)[:, + None] + x = x1 / (1 + tl.exp(-x1)) * x2 + + if CALIBRATE: + maxs = tl.maximum(x.abs(), maxs) + + x *= w * smooth_scale + scale = tl.maximum(tl.max(x.abs(), 1) / 448, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(scale_ptr + si + indices, scale, mask=indices < count) + x = (x / scale[:, None]).to(out_ptr.dtype.element_ty) + tl.store(out_ptr + row_offs + col_offs, x, mask=mask) + row_offs += n * W + + if CALIBRATE: + maxs = tl.max(maxs, 0) + tl.store(max_ptr + eid * sm * n + tid * n + tl.arange(0, n), maxs) + + +# used in routed experts +def triton_batch_weighted_silu_and_smooth_quant_forward(x, + weight, + counts, + smooth_scale=None, + splits=None, + out=None, + scale=None, + round_scale=False, + reverse=False, + calibrate=False): + """ + + """ + M, N = x.shape + n = N // 2 + n_experts = counts.shape[0] + assert N <= 8192 + device = x.device + if out is None: + out = torch.empty((M, n), device=device, dtype=torch.float8_e4m3fn) + + sm = torch.cuda.get_device_properties(device).multi_processor_count + tmp_maxs = None + if scale is None: + scale = torch.empty((M,), device=device, dtype=torch.float32) + if M == 0: + maxs = torch.zeros((n_experts, n), device=device, + dtype=torch.float32) + + elif calibrate: + tmp_maxs = torch.empty((n_experts, sm, n), device=device, + dtype=torch.float32) + maxs = torch.empty((n_experts, n), device=device, + dtype=torch.float32) + else: + maxs = None + + if M == 0: + return out, scale, maxs + + accums = torch.cumsum(counts, 0) + W = 8192 // N + grid = (n_experts, sm) + batch_weighted_silu_and_smooth_quant_forward_kernel[grid]( + x, + weight, + smooth_scale, + out, + scale, + tmp_maxs, + counts, + accums, + M, + n, + W, + round_scale, + reverse, + calibrate, + num_stages=3, + num_warps=16 + ) + if calibrate: + maxs = tmp_maxs.amax(1) + + return out, scale, maxs + + +@triton.jit +def batch_weighted_silu_and_smooth_quant_backward_kernel(g_ptr, x_ptr, + weight_ptr, + smooth_scale_ptr, + transpose_smooth_scale_ptr, + count_ptr, + accum_ptr, + dx_ptr, + dx_scale_ptr, + transpose_dx_ptr, + transpose_dx_scale_ptr, + dw_ptr, + n: tl.constexpr, + T: tl.constexpr, + B: tl.constexpr, + E: tl.constexpr, + REVERSE: tl.constexpr, + ROUND: tl.constexpr): + eid = tl.program_id(axis=0) + pid = tl.program_id(axis=1) + max_block = tl.num_programs(axis=1) + + count = tl.load(count_ptr + eid) + round_count = tl.cdiv(count, 32) * 32 + si = tl.load(accum_ptr + eid) - count + + if pid >= tl.cdiv(count, T): + return + + round_off = tl.sum(tl.where(tl.arange(0, E) < eid, + tl.cdiv(tl.load(count_ptr + tl.arange(0, E)), + 32), 0)) * 32 + + offs = si * n * 2 + pid * T * n * 2 + tl.arange(0, T)[:, + None] * n * 2 + tl.arange(0, B)[ + None, :] + hoffs = si * n + pid * T * n + tl.arange(0, T)[:, None] * n + tl.arange(0, + B)[ + None, + :] + toffs = 
round_off * n * 2 + pid * T + tl.arange(0, B)[:, + None] * round_count + tl.arange(0, T)[ + None, :] + nb = n // B + maxs = tl.zeros((T,), dtype=tl.float32) + indices = pid * T + tl.arange(0, T) + if REVERSE: + transpose_smooth_scale = tl.load( + transpose_smooth_scale_ptr + si + pid * T + tl.arange(0, T), + mask=indices < count)[:, None] + else: + transpose_smooth_scale = 1 / tl.load( + transpose_smooth_scale_ptr + si + pid * T + tl.arange(0, T), + mask=indices < count, other=1e-30)[:, None] + + w = tl.load(weight_ptr + si + pid * T + tl.arange(0, T), + mask=indices < count)[:, None] + dw = tl.zeros((T,), dtype=tl.float32) + qdtype = transpose_dx_ptr.dtype.element_ty + for i in range(nb): + smooth_scale_1 = tl.load( + smooth_scale_ptr + eid * n * 2 + i * B + tl.arange(0, B)) + smooth_scale_2 = tl.load( + smooth_scale_ptr + eid * n * 2 + n + i * B + tl.arange(0, B)) + if not REVERSE: + smooth_scale_1 = 1 / smooth_scale_1 + smooth_scale_2 = 1 / smooth_scale_2 + + x1 = tl.load(x_ptr + offs, mask=indices[:, None] < count).to(tl.float32) + x2 = tl.load(x_ptr + offs + n, mask=indices[:, None] < count).to( + tl.float32) + g = tl.load(g_ptr + hoffs, mask=indices[:, None] < count).to(tl.float32) + sigmoid = 1 / (1 + tl.exp(-x1)) + dx1 = g * x2 * sigmoid * ( + 1 + x1 * (1 - sigmoid)) * w + dx2 = g * x1 * sigmoid * w + + dw += tl.sum(x1 * sigmoid * x2 * g, 1) + + t_dx = dx1 * transpose_smooth_scale + t_s = tl.maximum(tl.max(tl.abs(t_dx), 0) / 448, 1e-30) + if ROUND: + t_s = tl.exp2(tl.ceil(tl.log2(t_s))) + t_dx = t_dx / t_s + tl.store(transpose_dx_ptr + toffs, tl.trans(t_dx.to(qdtype)), + mask=indices[None, :] < round_count) + tl.store( + transpose_dx_scale_ptr + eid * max_block * n * 2 + pid * n * 2 + i * B + tl.arange( + 0, B), t_s) + + t_dx = dx2 * transpose_smooth_scale + t_s = tl.maximum(tl.max(tl.abs(t_dx), 0) / 448, 1e-30) + if ROUND: + t_s = tl.exp2(tl.ceil(tl.log2(t_s))) + t_dx = t_dx / t_s + tl.store(transpose_dx_ptr + round_count * n + toffs, + tl.trans(t_dx.to(qdtype)), mask=indices[None, :] < round_count) + tl.store( + transpose_dx_scale_ptr + eid * max_block * n * 2 + pid * n * 2 + n + i * B + tl.arange( + 0, B), t_s) + + dx1 = dx1 * smooth_scale_1 + dx2 = dx2 * smooth_scale_2 + maxs = tl.maximum( + tl.maximum(tl.max(dx1.abs(), 1), tl.max(dx2.abs(), 1)), maxs) + + offs += B + hoffs += B + toffs += B * round_count + + tl.store(dw_ptr + si + pid * T + tl.arange(0, T), dw, mask=indices < count) + scale = tl.maximum(maxs / 448, 1e-30) + if ROUND: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + tl.store(dx_scale_ptr + si + pid * T + tl.arange(0, T), scale, + mask=indices < count) + + s = 1 / scale[:, None] + offs = si * n * 2 + pid * T * n * 2 + tl.arange(0, T)[:, + None] * n * 2 + tl.arange(0, B)[ + None, :] + hoffs = si * n + pid * T * n + tl.arange(0, T)[:, None] * n + tl.arange(0, + B)[ + None, + :] + for i in range(nb): + smooth_scale_1 = tl.load( + smooth_scale_ptr + eid * n * 2 + i * B + tl.arange(0, B)) + smooth_scale_2 = tl.load( + smooth_scale_ptr + eid * n * 2 + n + i * B + tl.arange(0, B)) + if not REVERSE: + smooth_scale_1 = 1 / smooth_scale_1 + smooth_scale_2 = 1 / smooth_scale_2 + + x1 = tl.load(x_ptr + offs, mask=indices[:, None] < count).to(tl.float32) + x2 = tl.load(x_ptr + offs + n, mask=indices[:, None] < count).to( + tl.float32) + g = tl.load(g_ptr + hoffs, mask=indices[:, None] < count).to(tl.float32) + sigmoid = 1 / (1 + tl.exp(-x1)) + dx1 = g * x2 * sigmoid * ( + 1 + x1 * (1 - sigmoid)) * smooth_scale_1 * w + dx2 = g * x1 * sigmoid * smooth_scale_2 * w + + dx1 = (dx1 * 
s).to(dx_ptr.dtype.element_ty) + dx2 = (dx2 * s).to(dx_ptr.dtype.element_ty) + + tl.store(dx_ptr + offs, dx1, mask=indices[:, None] < count) + tl.store(dx_ptr + n + offs, dx2, mask=indices[:, None] < count) + offs += B + hoffs += B + + +# requant multi-column quantized tensor +@triton.jit +def _batch_requant_kernel(x_ptr, scale_ptr, scales_ptr, + count_ptr, + N, + H: tl.constexpr, + W: tl.constexpr, + E: tl.constexpr + ): + eid = tl.program_id(axis=0) + rid = tl.program_id(axis=1) + cid = tl.program_id(axis=2) + max_block = tl.num_programs(axis=2) + + count = tl.load(count_ptr + eid) + round_count = tl.cdiv(count, 32) * 32 + if cid >= tl.cdiv(round_count, W): + return + + round_off = tl.sum(tl.where(tl.arange(0, E) < eid, + tl.cdiv(tl.load(count_ptr + tl.arange(0, E)), + 32) * 32, 0)) + + offs = round_off * N + rid * H * round_count + cid * W + tl.arange(0, H)[:, + None] * round_count + tl.arange( + 0, W)[None, :] + global_scale = tl.load(scale_ptr + eid * N + rid * H + tl.arange(0, H)) + # scales is stored with column-major format + local_scale = tl.load( + scales_ptr + max_block * N * eid + cid * N + rid * H + tl.arange(0, H)) + x = tl.load(x_ptr + offs).to(tl.float32) + rescale = local_scale / tl.maximum(global_scale, 1e-30) + x = x * rescale[:, None] + tl.store(x_ptr + offs, x) + + +# used in routed experts +def triton_batch_weighted_silu_and_smooth_quant_backward(g, x, weight, + counts, + smooth_scale=None, + transpose_smooth_scale=None, + splits=None, + reverse=True, + round_scale=False): + """ + + """ + assert round_scale + M, N = x.shape + n = N // 2 + n_expert = counts.shape[0] + assert N <= 8192 and 8192 % N == 0 + assert splits is not None, 'batch mode need splits to launch kernels' + + device = x.device + + accums = torch.cumsum(counts, 0) + + dx = torch.empty((M, N), device=device, dtype=torch.float8_e4m3fn) + + dx_scale = torch.empty((M,), device=device, dtype=torch.float32) + + dw = torch.empty_like(weight) + T = 32 + B = 32 + assert n % B == 0 and T == 32 + max_block = triton.cdiv(max(splits), T) + s = sum([(x + 31) // 32 for x in splits]) * 32 + transpose_dx = torch.empty((N * s,), device=device, + dtype=torch.float8_e4m3fn) + + if s == 0: + transpose_dx_scale = torch.zeros((n_expert, N), device=device, + dtype=torch.float32) + return dx, dx_scale, dw, transpose_dx, transpose_dx_scale + else: + transpose_dx_scales = torch.zeros((n_expert, max_block, N), + device=device, dtype=torch.bfloat16) + + grid = (n_expert, max_block) + batch_weighted_silu_and_smooth_quant_backward_kernel[grid]( + g, + x, + weight, + smooth_scale, + transpose_smooth_scale, + counts, + accums, + dx, + dx_scale, + transpose_dx, + transpose_dx_scales, + dw, + n, + T, + B, + n_expert, + reverse, + round_scale, + num_stages=5, + num_warps=4 + ) + transpose_dx_scale = transpose_dx_scales.amax(1).float() + grid = (n_expert, N // B, max_block) + _batch_requant_kernel[grid](transpose_dx, transpose_dx_scale, + transpose_dx_scales, + counts, + N, + B, + T, + n_expert, + num_stages=3, + num_warps=2) + + return dx, dx_scale, dw, transpose_dx, transpose_dx_scale + diff --git a/linghe/utils/transpose.py b/linghe/utils/transpose.py index 4392fb3..bfc1fcf 100644 --- a/linghe/utils/transpose.py +++ b/linghe/utils/transpose.py @@ -4,7 +4,7 @@ """ import itertools - +from typing import Optional import torch import triton import triton.language as tl @@ -15,54 +15,6 @@ # os.environ["TRITON_PRINT_AUTOTUNING"] = "1" -@triton.jit -def deprecated_transpose_kernel(x_ptr, t_ptr, M, N, H: tl.constexpr, - W: tl.constexpr, EVEN: 
tl.constexpr):
-    pid = tl.program_id(axis=0)
-    # col-wise read, row-wise write
-    offs = pid * W + tl.arange(0, H)[:, None] * N + tl.arange(0, W)[None, :]
-    toffs = pid * W * M + tl.arange(0, W)[:, None] * M + tl.arange(0, H)[None, :]
-    m = tl.cdiv(M, H)
-    for i in range(m):
-        if EVEN:
-            y = tl.trans(tl.load(x_ptr + offs))
-            tl.store(t_ptr + toffs, y)
-        else:
-            y = tl.trans(tl.load(x_ptr + offs,
-                                 mask=(pid * W + tl.arange(0, W)[None, :] < N) & (
-                                         i * H + tl.arange(0, H)[:, None] < M)))
-            tl.store(t_ptr + toffs, y,
-                     mask=(pid * W + tl.arange(0, W)[:, None] < N) & (
-                             i * H + tl.arange(0, H)[None, :] < M))
-        offs += H * N
-        toffs += H
-
-
-def triton_depracated_transpose(x):
-    M, N = x.shape
-    device = x.device
-    t = torch.empty((N, M), device=device, dtype=x.dtype)
-
-    H = 512
-    W = 32 if x.dtype.itemsize == 1 else 16
-    EVEN = M % H == 0 and N % W == 0
-    num_stages = 3
-    num_warps = 8
-
-    grid = (triton.cdiv(N, W),)
-    deprecated_transpose_kernel[grid](
-        x, t,
-        M, N,
-        H, W,
-        EVEN,
-        num_stages=num_stages,
-        num_warps=num_warps
-    )
-    return t
-
 
 @triton.jit
 def transpose_kernel(x_ptr, t_ptr, M, N, H: tl.constexpr, W: tl.constexpr,
@@ -99,7 +51,19 @@ def transpose_dim_0_1_kernel(x_ptr, t_ptr, B, M, b_stride, m_stride,
     tl.store(t_ptr + toffs, y)
 
 
-def triton_transpose(x, dim0=None, dim1=None):
+def triton_transpose(x: torch.Tensor,
+                     dim0: Optional[int] = None,
+                     dim1: Optional[int] = None):
+    """
+    transpose x along dim0 and dim1
+    Args:
+        x: input tensor
+        dim0: first dimension to swap
+        dim1: second dimension to swap
+
+    Returns:
+        transposed tensor
+    """
     shape = x.shape
     rank = len(shape)
     assert rank <= 4
@@ -180,13 +144,19 @@ def transpose_and_pad_kernel(x_ptr, t_ptr,
              mask=(rid * H + tl.arange(0, H)[None, :] < P))
 
 
-"""
-pad: M will be padded to mutiplier of 32
-M is usually less than N without deepep
-"""
-
 def triton_transpose_and_pad(x, out=None, pad=True):
+    """
+    transpose x and pad the column size to a multiple of 32;
+    it is used when calculating the weight gradient with torch._scaled_mm
+    Args:
+        x: input tensor
+        out: optional output buffer
+        pad: whether padding is needed
+
+    Returns:
+        out: output tensor
+    """
     # fat block, shape:[H,W]
     M, N = x.shape
     P = round_up(M, b=32) if pad else M
@@ -228,14 +198,14 @@ def batch_transpose_kernel(xs_ptr, xts_ptr, M, N, H: tl.constexpr,
         toffs += H
 
 
-"""
-x: [M, N]*expert
-x_t: [N,M]*expert
-"""
-
-
 def triton_batch_transpose(xs, xts=None):
-    # block shape:[H,W]
+    """
+    batch-transpose a list of tensors
+    Args:
+        xs: input tensor list, [M, N]*expert
+    Returns:
+        xts: output tensor list, [N, M]*expert
+    """
     M, N = xs[0].shape
     n_experts = len(xs)
     if xts is None:
@@ -286,16 +256,18 @@ def batch_transpose_and_pad_kernel(x_ptr, t_ptr, count_ptr, accum_ptr,
         toffs += H
 
 
-"""
-pad: M will be padded to mutiplier of 32
-padding should be filled with 0
-M is usually less than N
-x: [sum(bs), N]
-x_t: [sum(pad(bs)*N)]
-"""
-
-
 def triton_batch_transpose_and_pad(x, count_list, x_t=None, pad=True):
+    """
+    transpose and pad each tensor stored in x
+    Args:
+        x: [sum(bs), N]
+        count_list: a python list of token counts
+        pad: whether to pad to a multiple of 32;
+            padding values should be filled with 0 if padded
+
+    Returns:
+        x_t: output tensor
+    """
     assert pad
     # block shape:[H,W]
     M, N = x.shape
diff --git a/setup.py b/setup.py
index cd094eb..bc21ae2 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
     license="MIT",
     license_files=("LICENSE",),
     description="LLM traning kernels",
-    URL="https://code.alipay.com/pia/linghe",
+    URL="https://github.com/inclusionAI/linghe",
     packages=find_packages(),
     install_requires=[],
python_requires=">=3.8", diff --git a/tests/test_channel_quant.py b/tests/test_channel_quant.py new file mode 100644 index 0000000..5b4fd49 --- /dev/null +++ b/tests/test_channel_quant.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import torch + +from linghe.quant.channel import (triton_deprecated_tokenwise_row_quant, + triton_row_quant, + triton_tokenwise_row_quant) +from linghe.tools.benchmark import benchmark_func +from linghe.tools.util import (output_check, + torch_row_quant) + + +def test_row_quant(M=4096, N=4096, round_scale=True, bench=False): + device = 'cuda:0' + dtype = torch.bfloat16 + x = torch.randn((M, N), dtype=dtype, device=device) ** 3 + + x_q_ref, x_scale_ref = torch_row_quant(x, round_scale=round_scale) + + x_q, x_scale = triton_row_quant(x, round_scale=round_scale) + output_check(x_q_ref.float(), x_q.float(), mode='data') + output_check(x_scale_ref, x_scale, mode='scale') + + x_q, x_scale = triton_tokenwise_row_quant(x, round_scale=round_scale) + output_check(x_q_ref.float(), x_q.float(), mode='data') + output_check(x_scale_ref, x_scale, mode='scale') + + if bench: + ref_time = benchmark_func(torch_row_quant, x, n_repeat=100, + ref_bytes=M * N * 3) + benchmark_func(triton_row_quant, x, n_repeat=100, ref_bytes=M * N * 3, + ref_time=ref_time) + benchmark_func(triton_deprecated_tokenwise_row_quant, x, n_repeat=100, + ref_bytes=M * N * 3, ref_time=ref_time) + benchmark_func(triton_tokenwise_row_quant, x, n_repeat=100, + ref_bytes=M * N * 3, ref_time=ref_time) + + +if __name__ == '__main__': + test_row_quant(M=4096, N=4096, round_scale=False) + test_row_quant(M=4090, N=4096, round_scale=True) + test_row_quant(M=4096, N=8192, round_scale=True) + test_row_quant(M=3456, N=2048, round_scale=True) + test_row_quant(M=1, N=2048, round_scale=True) diff --git a/tests/test_dot.py b/tests/test_dot.py index 699fe1a..fcedbd6 100644 --- a/tests/test_dot.py +++ b/tests/test_dot.py @@ -7,8 +7,7 @@ from linghe.tools.benchmark import benchmark_func from linghe.tools.util import output_check -from linghe.utils.dot import (triton_dot, - triton_mix_precise_dot) +from linghe.utils.dot import triton_dot def torch_fp16_dot(x, y): @@ -34,13 +33,9 @@ def test_dot(M=4096, N=4096, bench=False): sums_ref = (x.float() * ( q.to(torch.float32) * quant_scale[:, None] * smooth_scale[None, :])).sum(dim=1) - sums = triton_mix_precise_dot(x, q, smooth_scale, quant_scale, reverse=True) - output_check(sums_ref, sums.float(), 'sum') if bench: ref_time = benchmark_func(torch_fp16_dot, x, y, n_repeat=n_repeat) - benchmark_func(triton_mix_precise_dot, x, q, smooth_scale, quant_scale, - reverse=True, n_repeat=n_repeat, ref_time=ref_time) if __name__ == '__main__': diff --git a/tests/test_gemm.py b/tests/test_fp32_gemm.py similarity index 85% rename from tests/test_gemm.py rename to tests/test_fp32_gemm.py index 68fdcf9..3f60d18 100644 --- a/tests/test_gemm.py +++ b/tests/test_fp32_gemm.py @@ -14,6 +14,17 @@ from linghe.tools.util import output_check + +def torch_fp32_matmul(x, w): + return torch.nn.functional.linear(x.float(), w.float()) + +def torch_fp32_matmul_backward(dy, w): + return (dy @ w).to(torch.bfloat16) + +def torch_fp32_matmul_update(y, x): + return (y.t() @ x).to(torch.bfloat16) + + def test_fp32_matmul(M=2048, N=256, K=8192, bench=False): # M, N, K = 4096, 256, 8192 dtype = torch.bfloat16 @@ -25,15 +36,6 @@ def test_fp32_matmul(M=2048, N=256, K=8192, bench=False): scale = torch.randn(M, dtype=torch.float32, device=device) dy = 
torch.randn(M, N, dtype=torch.float32, device=device) - def torch_fp32_matmul(x, w): - return torch.nn.functional.linear(x.float(), w.float()) - - def torch_fp32_matmul_backward(dy, w): - return (dy @ w).to(torch.bfloat16) - - def torch_fp32_matmul_update(y, x): - return (y.t() @ x).to(torch.bfloat16) - y_ref = torch_fp32_matmul(x, w) y = triton_fp32_gemm(x, w) output_check(y_ref, y.float(), mode='fp32_gemm') @@ -43,10 +45,9 @@ def torch_fp32_matmul_update(y, x): output_check(y_ref, y.float(), mode='scaled_fp32_gemm') dx = torch.zeros(M, K, dtype=dtype, device=device) - dx_clone = dx.clone() - triton_fp32_gemm_for_backward(dy, w, dx_clone, accum=True) - dx_ref = dy @ w.float() + dx.float() - output_check(dx_ref, dx_clone.float(), mode='backward') + dx = triton_fp32_gemm_for_backward(dy, w) + dx_ref = dy @ w.float() + output_check(dx_ref, dx.float(), mode='backward') main_grad = triton_fp32_gemm_for_update(y, x) main_grad_ref = y.t() @ (x.float()) @@ -72,8 +73,8 @@ def torch_fp32_matmul_update(y, x): n_repeat=n_repeat, ref_bytes=M * K * 10 + N * K * 4 + M * N * 4, ref_linghe=2 * M * N * K) - benchmark_func(triton_fp32_gemm_for_backward, dy, w, dx_clone, - accum=True, n_repeat=n_repeat, + benchmark_func(triton_fp32_gemm_for_backward, dy, w, + n_repeat=n_repeat, ref_bytes=M * K * 2 + N * K * 2 + M * N * 4, ref_linghe=2 * M * N * K, ref_time=ref_time) diff --git a/tests/test_gather.py b/tests/test_gather.py index bbbb58b..4606284 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -5,46 +5,48 @@ import torch -from linghe.tools.benchmark import benchmark_func -from linghe.tools.util import (output_check, - torch_make_indices) from linghe.utils.gather import (triton_make_row_id_map, triton_make_row_id_map_and_indices, triton_index_select, - triton_permute_with_mask_map) + triton_permute_with_mask_map, + triton_smooth_permute_with_indices, + triton_smooth_permute_with_mask_map, + triton_smooth_weighted_permute_with_indices, + triton_batch_transpose_smooth_permute_with_indices) +from linghe.tools.util import (output_check, + torch_batch_smooth_quant, + torch_make_indices, + torch_smooth_quant) +from linghe.tools.benchmark import benchmark_func def torch_index_select(y, indices): output = y.index_select(0, indices) return output - def torch_select_with_padded_map_mask(y, mask_map, out_tokens): E = mask_map.shape[1] if y.ndim > 1: - output = torch.zeros((out_tokens, y.shape[1]), dtype=y.dtype, - device=y.device) + output = torch.zeros((out_tokens, y.shape[1]), dtype=y.dtype, device=y.device) else: - output = torch.zeros((out_tokens,), dtype=y.dtype, device=y.device) + output = torch.zeros((out_tokens, ), dtype=y.dtype, device=y.device) for i in range(E): - indices = mask_map[:, i] - src_idx = torch.nonzero(indices > -1) + indices = mask_map[:,i] + src_idx = torch.nonzero(indices>-1) dst_idx = indices[src_idx] output[dst_idx] = y[src_idx] return output - def torch_ravel_with_padded_map_mask(y, mask_map, out_tokens): E = mask_map.shape[1] - output = torch.zeros((out_tokens,), dtype=y.dtype, device=y.device) + output = torch.zeros((out_tokens, ), dtype=y.dtype, device=y.device) for i in range(E): - indices = mask_map[:, i] - src_idx = torch.nonzero(indices > -1) + indices = mask_map[:,i] + src_idx = torch.nonzero(indices>-1) dst_idx = indices[src_idx] - output[dst_idx] = y[src_idx, i] + output[dst_idx] = y[src_idx,i] return output - def torch_fp16_index_select(x, scales, indices): return x.index_select(0, indices), scales.index_select(0, indices) @@ -53,6 +55,78 @@ def torch_scatter(logits, 
routing_map, weights): logits[routing_map] = weights +# optional dequant and smooth and quant +def torch_smooth_permute_with_indices(grad_data, grad_scale, indices, + smooth_scales, + token_count_per_expert_list, + round_scale=True): + M, N = grad_data.shape + if grad_scale is not None: + B = grad_data.shape[1] // ( + 1 if grad_scale.ndim == 1 else grad_scale.shape[1]) + q_refs = [] + scale_refs = [] + s = 0 + for i, c in enumerate(token_count_per_expert_list): + c = token_count_per_expert_list[i] + data_slice = grad_data.view(torch.uint8)[indices[s:s + c]].view( + torch.float8_e4m3fn) + if grad_scale is not None: + scale_slice = grad_scale[indices[s:s + c]] + y_smooth = (data_slice.float().view(c, N // B, B) * scale_slice[:, :, + None]).view(c, N) / \ + smooth_scales[i] + else: + y_smooth = data_slice.float() / smooth_scales[i] + scale = y_smooth.abs().amax(1) / 448 + if round_scale: + scale = torch.exp2(torch.ceil(torch.log2(scale))) + scale_refs.append(scale) + q = (y_smooth / scale[:, None]).to(torch.float8_e4m3fn) + q_refs.append(q.view(torch.uint8)) + s += c + q_ref = torch.cat(q_refs, 0).view(torch.float8_e4m3fn) + scale_ref = torch.cat(scale_refs, 0) + + return q_ref, scale_ref + + + +# desmooth,dequant, gather, pad, transpose, smooth, quant +def torch_batch_transpose_smooth_permute_with_indices(x_q, x_scale, org_smooth_scale, smooth_scales, + indices, + token_count_per_expert_list, + round_scale=True): + M, DIM = x_q.shape + q_refs = [] + scale_refs = [] + s = 0 + for i, c in enumerate(token_count_per_expert_list): + c = token_count_per_expert_list[i] + if c == 0: + y_scale = torch.zeros((DIM,), dtype=torch.float32, device=x_q.device) + scale_refs.append(y_scale.view(-1)) + continue + N = (c + 31)//32 * 32 + data_slice = x_q[indices[s:s + c]] + if x_scale is not None: + scale_slice = x_scale[indices[s:s + c]] + y = data_slice.float() * scale_slice[:, None] * org_smooth_scale + else: + y = data_slice.float() + smooth_scale = smooth_scales[s:s+c] + if N > c: + y = torch.nn.functional.pad(y, (0,0,0, N-c)) + smooth_scale = torch.nn.functional.pad(smooth_scale, (0, N-c)) + y_q, y_scale, y_max= torch_smooth_quant(y.t().contiguous(), smooth_scale, reverse=True, round_scale=round_scale) + scale_refs.append(y_scale.view(-1)) + q_refs.append(y_q.view(-1)) + s += c + q_ref = torch.cat(q_refs, 0) + scale_ref = torch.stack(scale_refs, 0) + return q_ref, scale_ref + + def test_make_id_map(M=4098, n_experts=32, topk=2, bias=0.0, bench=False): dtype = torch.bfloat16 device = 'cuda:0' @@ -64,6 +138,7 @@ def test_make_id_map(M=4098, n_experts=32, topk=2, bias=0.0, bench=False): token_count_per_expert_list = token_count_per_expert.tolist() out_tokens = sum(token_count_per_expert_list) + row_id_map_output = triton_make_row_id_map(mask_map) assert (row_id_map - row_id_map_output).abs().sum().item() == 0 @@ -71,6 +146,45 @@ def test_make_id_map(M=4098, n_experts=32, topk=2, bias=0.0, bench=False): assert (row_id_indices - indices).abs().sum().item() == 0 + +def test_triton_smooth_weighted_permute_with_indices(M=4096, N=4096, + n_experts=256, + topk=8, + round_scale=True, + bench=False): + device = 'cuda:0' + reverse = True + y = torch.randn((M, N), dtype=torch.bfloat16, device=device) + logits = torch.randn((M, n_experts), dtype=torch.float32, device=device) + smooth_scales = 1 + 10 * torch.rand((n_experts, N), device=device, + dtype=torch.float32) + probs, mask_map, token_count_per_expert, indices, row_id_map = torch_make_indices( + logits, topk=topk, bias=0.0) + + tokens = 
torch.randn((indices.shape[0], N), dtype=torch.bfloat16, + device=device) + y_q, y_scale, y_sum = triton_smooth_weighted_permute_with_indices( + y, tokens, smooth_scales, token_count_per_expert, indices, x_q=None, + x_scale=None, reverse=reverse, round_scale=round_scale) + + y_q_ref, y_scale_ref = torch_batch_smooth_quant(y, smooth_scales, indices, + token_count_per_expert, + reverse=reverse, + round_scale=round_scale) + sum_ref = (tokens * y[indices]).sum(1) + + output_check(y_q_ref.float(), y_q.float(), 'data') + output_check(y_scale_ref.float(), y_scale.float(), 'scale') + output_check(sum_ref.float(), y_sum.float(), 'sum') + + if bench: + n_repeat = 100 + benchmark_func(triton_smooth_weighted_permute_with_indices, + y, tokens, smooth_scales, token_count_per_expert, + indices, reverse=reverse, round_scale=round_scale, + n_repeat=n_repeat) + + def test_triton_permute_with_mask_map(M=4096, N=4096, n_experts=256, topk=8, bench=False): device = 'cuda:0' @@ -100,19 +214,15 @@ def test_triton_permute_with_mask_map(M=4096, N=4096, n_experts=256, topk=8, output_check(scale_out_ref, scale_out, 'scale_out') output_check(probs_out_ref, probs_out, 'prob_out') - nzs = torch.sum(row_id_map >= 0, 0) - bias = torch.cumsum((nzs + 15) // 16 * 16 - nzs, 0) + nzs = torch.sum(row_id_map>=0, 0) + bias = torch.cumsum((nzs + 15)//16*16 - nzs, 0) row_id_map_clone = row_id_map.clone().detach() row_id_map_clone[:, 1:] += bias[:-1] - round_row_id_map = torch.where(row_id_map >= 0, row_id_map_clone, -1) - padded_out_tokens = sum( - [(x + 15) // 16 * 16 for x in token_count_per_expert.tolist()]) - x_out_ref = torch_select_with_padded_map_mask(x, round_row_id_map, - padded_out_tokens) - scale_out_ref = torch_select_with_padded_map_mask(scales, round_row_id_map, - padded_out_tokens) - prob_out_ref = torch_ravel_with_padded_map_mask(probs, round_row_id_map, - padded_out_tokens) + round_row_id_map = torch.where(row_id_map>=0, row_id_map_clone, -1) + padded_out_tokens = sum([(x+15)//16*16 for x in token_count_per_expert.tolist()]) + x_out_ref = torch_select_with_padded_map_mask(x, round_row_id_map, padded_out_tokens) + scale_out_ref = torch_select_with_padded_map_mask(scales, round_row_id_map, padded_out_tokens) + prob_out_ref = torch_ravel_with_padded_map_mask(probs, round_row_id_map, padded_out_tokens) x_out, scale_out, probs_out = triton_permute_with_mask_map(x, scales, probs, round_row_id_map, padded_out_tokens, @@ -128,11 +238,9 @@ def test_triton_permute_with_mask_map(M=4096, N=4096, n_experts=256, topk=8, ref_time = benchmark_func(torch_fp16_index_select, x, scales, indices, n_repeat=n_repeat, ref_bytes=ref_bytes) benchmark_func(triton_index_select, x, indices, scale=scales, - n_repeat=n_repeat, ref_time=ref_time, - ref_bytes=ref_bytes) + n_repeat=n_repeat, ref_time=ref_time, ref_bytes=ref_bytes) benchmark_func(triton_permute_with_mask_map, x, scales, probs, - row_id_map, out_tokens, contiguous=True, - n_repeat=n_repeat, + row_id_map, out_tokens, contiguous=True, n_repeat=n_repeat, ref_time=ref_time, ref_bytes=ref_bytes) benchmark_func(triton_permute_with_mask_map, x, scales, probs, row_id_map, out_tokens, contiguous=False, @@ -141,11 +249,156 @@ def test_triton_permute_with_mask_map(M=4096, N=4096, n_experts=256, topk=8, ref_time=ref_time, ref_bytes=ref_bytes) +def test_triton_smooth_permute_with_mask_map(M=4096, N=4096, n_experts=32, + topk=8, round_scale=True, + bench=False): + device = 'cuda:0' + dtype = torch.bfloat16 + smooth_scales = 1 + 10 * torch.rand((n_experts, N), device=device, + dtype=torch.float32) 
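# A side note on the padded-map arithmetic verified in
# test_triton_permute_with_mask_map above: each expert's destination segment is
# rounded up to a multiple of 16, and later experts are shifted by the
# accumulated padding. A standalone sketch of the same arithmetic; the counts
# below are hypothetical, not taken from the patch:
import torch

nzs = torch.tensor([5, 16, 3])                  # routed tokens per expert
pad = (nzs + 15) // 16 * 16 - nzs               # -> tensor([11,  0, 13])
bias = torch.cumsum(pad, 0)                     # -> tensor([11, 11, 24])
# expert i > 0 is shifted by bias[i - 1], which is exactly what
# row_id_map_clone[:, 1:] += bias[:-1] does in the test
padded_out_tokens = int(((nzs + 15) // 16 * 16).sum())  # -> 48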
+ logits = torch.randn((M, n_experts), dtype=torch.float32, + device=device) ** 3 + probs, mask_map, token_count_per_expert, indices, row_id_map = torch_make_indices( + logits, topk=topk, bias=-0.01) + + token_count_per_expert_list = token_count_per_expert.tolist() + out_tokens = sum(token_count_per_expert_list) + + B = 128 + grad_data = torch.randn((M, N), dtype=torch.bfloat16, device=device).to( + torch.float8_e4m3fn) + grad_scale = 1 + torch.rand((M, N // B), dtype=torch.float32, device=device) + q_ref, scale_ref = torch_smooth_permute_with_indices(grad_data, grad_scale, + indices, smooth_scales, + token_count_per_expert_list, + round_scale=round_scale) + y_q, y_scale = triton_smooth_permute_with_indices(grad_data, + grad_scale, + smooth_scales, + token_count_per_expert, + indices, + x_q=None, + x_scale=None, + reverse=False, + round_scale=round_scale) + output_check(q_ref.float(), y_q.float(), 'data') + output_check(scale_ref.float(), y_scale.float(), 'scale') + + + + # smooth_scale_ptrs = torch.tensor([x.data_ptr() for x in torch.split(smooth_scales,1)], device=device) + permuted_data, permuted_scale = triton_smooth_permute_with_mask_map( + grad_data, row_id_map, grad_scale, M, n_experts, out_tokens, N, + smooth_scales, reverse=False, round_scale=round_scale) + output_check(q_ref.float(), permuted_data.float(), 'smoothed.data') + output_check(scale_ref.float(), permuted_scale.float(), 'smoothed.scale') + + q_ref, scale_ref = torch_smooth_permute_with_indices(grad_data, None, + indices, smooth_scales, + token_count_per_expert_list, + round_scale=round_scale) + permuted_data, permuted_scale = triton_smooth_permute_with_mask_map( + grad_data, row_id_map, None, M, n_experts, out_tokens, N, + smooth_scales, reverse=False, round_scale=round_scale) + output_check(q_ref.float(), permuted_data.float(), 'smoothed.data') + output_check(scale_ref.float(), permuted_scale.float(), 'smoothed.scale') + + + + if bench: + benchmark_func(triton_smooth_permute_with_indices, grad_data, + grad_scale, smooth_scales, token_count_per_expert, + indices, round_scale=round_scale, n_repeat=100, + ref_bytes=out_tokens * N * 2) + benchmark_func(triton_smooth_permute_with_mask_map, grad_data, + row_id_map, grad_scale, M, n_experts, out_tokens, N, + smooth_scales, reverse=False, round_scale=round_scale, + n_repeat=100, ref_bytes=out_tokens * N * 2) + + + + +def test_triton_batch_transpose_smooth_permute_with_indices(M=1024, N=2048, n_experts=32, topk=8, bench=False): + + device = 'cuda:0' + if True: + logits = torch.randn((M, n_experts), dtype=torch.float32, + device=device) ** 3 + logits[:,0] -= 1000 + logits[:,2] -= 100 + probs, mask_map, token_count_per_expert, indices, row_id_map = torch_make_indices( + logits, topk=topk, bias=-0.01) + + token_count_per_expert_list = token_count_per_expert.tolist() + out_tokens = sum(token_count_per_expert_list) + + x = torch.randn((M, N), dtype=torch.bfloat16, device=device).to( + torch.float8_e4m3fn) + scale = torch.rand((M,), dtype=torch.float32, device=device) + 0.1 + org_smooth_scale = torch.rand((N,), dtype=torch.float32, device=device) + 0.1 + smooth_scales = torch.rand((out_tokens, ), dtype=torch.float32, device=device) + 0.1 + else: + # torch.save({"x":x, "scale":scale, "org_smooth_scale":org_smooth_scale,"smooth_scales":smooth_scales, "indices":indices, "token_count_per_expert":token_count_per_expert,"splits":splits}, '/tmp/debug.bin') + state = torch.load('/tmp/debug.bin') + x = state['x'] + scale = state['scale'] + org_smooth_scale = state['org_smooth_scale'] + 
smooth_scales = state['smooth_scales'] + indices = state['indices'] + token_count_per_expert = state['token_count_per_expert'] + token_count_per_expert_list = state['splits'] + out_tokens = sum(token_count_per_expert_list) + + + x_q_ref, x_scale_ref = torch_batch_transpose_smooth_permute_with_indices(x, scale, org_smooth_scale, smooth_scales, + indices, + token_count_per_expert_list, + round_scale=True) + + x_q, x_scale = triton_batch_transpose_smooth_permute_with_indices(x, scale, org_smooth_scale, smooth_scales, + indices, + token_count_per_expert, token_count_per_expert_list, + round_scale=True) + output_check(x_q_ref.float(), x_q.float(), 'smoothed.data') + output_check(x_scale_ref.float(), x_scale.float(), 'smoothed.scale') + + + x_q_ref, x_scale_ref = torch_batch_transpose_smooth_permute_with_indices(x, None, None, smooth_scales, + indices, + token_count_per_expert_list, + round_scale=True) + + x_q, x_scale = triton_batch_transpose_smooth_permute_with_indices(x, None, None, smooth_scales, + indices, + token_count_per_expert, token_count_per_expert_list, + round_scale=True) + output_check(x_q_ref.float(), x_q.float(), 'bf16.data') + output_check(x_scale_ref.float(), x_scale.float(), 'bf16.scale') + + + if bench: + benchmark_func(torch_batch_transpose_smooth_permute_with_indices, x, scale, org_smooth_scale, smooth_scales, + indices, + token_count_per_expert_list, + round_scale=True, + ref_bytes=out_tokens * N * 2) + benchmark_func(triton_batch_transpose_smooth_permute_with_indices, x, scale, org_smooth_scale, smooth_scales, + indices, + token_count_per_expert, token_count_per_expert_list, + round_scale=True, + ref_bytes=out_tokens * N * 2) + + if __name__ == '__main__': test_make_id_map(M=4098, n_experts=32, topk=2, bias=0.0, bench=False) - test_triton_permute_with_mask_map(M=16384, N=2048, n_experts=32, topk=8, - bench=False) - test_triton_permute_with_mask_map(M=8192, N=4096, n_experts=32, topk=8, - bench=False) - test_triton_permute_with_mask_map(M=7628, N=2048, n_experts=32, topk=8, - bench=False) + test_triton_permute_with_mask_map(M=16384, N=2048, n_experts=32, topk=8, bench=False) + test_triton_permute_with_mask_map(M=8192, N=4096, n_experts=32, topk=8, bench=False) + test_triton_permute_with_mask_map(M=7628, N=2048, n_experts=32, topk=8, bench=False) + + test_triton_smooth_permute_with_mask_map(M=4096, N=4096, n_experts=32, + topk=8) + test_triton_smooth_permute_with_mask_map(M=7628, N=2048, n_experts=32, + topk=8) + + test_triton_batch_transpose_smooth_permute_with_indices(M=16384, N=2048, n_experts=32, topk=2, bench=False) + test_triton_batch_transpose_smooth_permute_with_indices(M=8192, N=4096, n_experts=32, topk=2, bench=False) diff --git a/tests/test_group_quant.py b/tests/test_group_quant.py index c8c2af8..d8c68ec 100644 --- a/tests/test_group_quant.py +++ b/tests/test_group_quant.py @@ -5,8 +5,7 @@ import torch -from linghe.quant.block.group import (triton_group_quant, - triton_persist_group_quant) +from linghe.quant.group import triton_group_quant from linghe.tools.benchmark import benchmark_func from linghe.tools.util import (output_check, torch_group_quant) @@ -19,21 +18,10 @@ def test_group_quant(M=4096, N=4096, B=128, round_scale=False, bench=False): output_check(xq_ref.float(), xq.float(), mode='data') output_check(x_scale_ref.float(), x_scale.float(), mode='scale') - xq, x_scale = triton_persist_group_quant(x, group_size=B, - round_scale=round_scale) - output_check(xq_ref.float(), xq.float(), mode='data') - output_check(x_scale_ref.float(), x_scale.float(), 
mode='scale') - - # torch.testing.assert_close(xq_ref.float(), xq.float(), rtol=0.02, atol=0.02) - if bench: n_repeat = 100 - ref_time = benchmark_func(triton_group_quant, x, group_size=B, + benchmark_func(triton_group_quant, x, group_size=B, n_repeat=n_repeat, ref_bytes=M * N * 3) - benchmark_func(triton_persist_group_quant, x, group_size=B, - n_repeat=n_repeat, ref_time=ref_time, - ref_bytes=M * N * 3) - if __name__ == '__main__': test_group_quant(M=4096, N=4096, B=128) diff --git a/tests/test_hadamard_quant.py b/tests/test_hadamard_quant.py new file mode 100644 index 0000000..c4c6a57 --- /dev/null +++ b/tests/test_hadamard_quant.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import torch + +from linghe.quant.hadamard import triton_hadamard_quant +from linghe.tools.benchmark import benchmark_func +from linghe.tools.util import (output_check, + make_hadamard_matrix, + torch_hadamard_transform, + torch_row_quant, + ) + + + + +# apply hadamard transformation and quantization for x +def torch_hadamard_quant(x, hm, round_scale=False): + xh = torch_hadamard_transform(x, hm, side='right') + q, s = torch_row_quant(xh, round_scale=round_scale) + xht = torch_hadamard_transform(x.t().contiguous(), hm, side='right') + qt, st = torch_row_quant(xht, round_scale=round_scale) + + return xh,xht,q,s,qt,st + + +def test_hadamard_quant(M=8192, N=1024, K=2048, B=64, bench=False): + dtype = torch.bfloat16 + device = 'cuda:0' + x = torch.randn((M, K), dtype=dtype, device=device) + w = torch.randn((N, K), dtype=dtype, device=device) + dy = torch.randn((M, N), dtype=dtype, device=device) + + hm = make_hadamard_matrix(B, dtype=dtype, device=device, norm=True) + + + y_ref = x@w.t() + dx_ref = dy@w + dw_ref = dy.t()@x + + xh,xht,xq,xs,xqt,xst = torch_hadamard_quant(x, hm, round_scale=False) + wh,wht,wq,ws,wqt,wst = torch_hadamard_quant(w, hm, round_scale=False) + dyh,dyht,dyq,dys,dyqt,dyst = torch_hadamard_quant(dy, hm, round_scale=False) + + y = xh@wh.t() + dx = dyh@wht.t() + dw = dyht@xht.t() + + output_check(y_ref,y,'bf16.y') + output_check(dx_ref,dx,'bf16.dx') + output_check(dw_ref,dw,'bf16.dw') + + x_q, x_scale, xt_q, xt_scale = triton_hadamard_quant(x, hm) + output_check(xq, x_q, 'x.data') + output_check(xs, x_scale, 'x.scale') + output_check(xqt, xt_q, 'xt.data') + output_check(xst, xt_scale, 'xt.scale') + + + w_q, w_scale, wt_q, wt_scale = triton_hadamard_quant(w, hm) + output_check(wq, w_q, 'w.data') + output_check(ws, w_scale, 'w.scale') + output_check(wqt, wt_q, 'wt.data') + output_check(wst, wt_scale, 'wt.scale') + + + dy_q, dy_scale, dyt_q, dyt_scale = triton_hadamard_quant(dy, hm) + output_check(dyq, dy_q, 'dy.data') + output_check(dys, dy_scale, 'dy.scale') + output_check(dyqt, dyt_q, 'dyt.data') + output_check(dyst, dyt_scale, 'dyt.scale') + + + +if __name__ == '__main__': + test_hadamard_quant(M=8192, N=1024, K=2048, B=64, bench=False) \ No newline at end of file diff --git a/tests/test_norm.py b/tests/test_norm.py index 4ed2f8e..fb45261 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -6,14 +6,16 @@ import torch import torch.nn.functional as F -from linghe.tools.benchmark import benchmark_func +from linghe.utils.norm import (triton_rms_norm_and_smooth_quant_forward, + triton_rms_norm_and_block_quant_forward, + triton_rms_norm_backward, + triton_rms_norm_forward, + triton_group_norm_gate_forward, + triton_group_norm_gate_backward) from linghe.tools.util import (output_check, - torch_group_quant) -from 
linghe.utils.norm import (triton_rms_norm_and_block_quant_forward, - triton_rms_norm_backward, - triton_rms_norm_forward, - triton_group_norm_gate_forward, - triton_group_norm_gate_backward) + torch_smooth_quant, + torch_group_quant) +from linghe.tools.benchmark import benchmark_func def torch_rms_forward(x, weight): @@ -50,12 +52,11 @@ def torch_rms_backward(x, weight, dy): return x.grad, rmsnorm.weight.grad -def torch_rms_and_quant_forward(x, weight, smooth_scale=None, - round_scale=False): +def torch_rms_and_smooth_quant_forward(x, weight, smooth_scale=None, + round_scale=False): x = x.float() weight = weight.float() - if smooth_scale is not None: - smooth_scale = smooth_scale.float() + smooth_scale = smooth_scale.float() N = x.shape[-1] rmsnorm = torch.nn.RMSNorm( normalized_shape=N, @@ -66,20 +67,15 @@ def torch_rms_and_quant_forward(x, weight, smooth_scale=None, with torch.no_grad(): rmsnorm.weight.copy_(weight) y = rmsnorm(x) - # blockwise - y_q, y_scale = torch_group_quant(y, round_scale=round_scale) - yt_q, yt_scale = torch_group_quant(y.t(), round_scale=round_scale) - return y_q, y_scale, yt_q, yt_scale + # smooth + y_q, y_scale, y_maxs = torch_smooth_quant(y, smooth_scale, reverse=False, + round_scale=round_scale) + return y_q, y_scale, y_maxs -# backward of rms is bf16, do not need quant -def torch_rms_and_quant_backward(x, weight, dy, smooth_scale=None, - round_scale=False): +def torch_rms_and_block_quant_forward(x, weight, round_scale=False): x = x.float() weight = weight.float() - dy = dy.float() - if smooth_scale is not None: - smooth_scale = smooth_scale.float() N = x.shape[-1] rmsnorm = torch.nn.RMSNorm( normalized_shape=N, @@ -89,15 +85,11 @@ def torch_rms_and_quant_backward(x, weight, dy, smooth_scale=None, ) with torch.no_grad(): rmsnorm.weight.copy_(weight) - x = x.clone().detach().requires_grad_() y = rmsnorm(x) - y.backward(gradient=dy) - dx = x.grad - dw = rmsnorm.weight.grad - dx_q, dx_scale = torch_group_quant(dx, round_scale=round_scale) - dxt_q, dxt_scale = torch_group_quant(dx.t(), round_scale=round_scale) - - return dx_q, dx_scale, dw, dxt_q, dxt_scale + # blockwise + y_q, y_scale = torch_group_quant(y, round_scale=round_scale) + yt_q, yt_scale = torch_group_quant(y.t(), round_scale=round_scale) + return y_q, y_scale, yt_q, yt_scale @torch.compile @@ -153,6 +145,40 @@ def test_rmsnorm(M=4096, N=4096, bench=False): ref_bytes=M * N * 3) +def test_rmsnorm_and_smooth_quant(M=4096, N=4096, bench=False): + dtype = torch.bfloat16 + device = 'cuda:0' + + x = torch.randn(M, N, dtype=dtype, requires_grad=True, device=device) + weight = torch.randn(N, dtype=dtype, requires_grad=True, device=device) + smooth_scale = torch.rand(N, dtype=torch.float32, requires_grad=False, + device=device) + 0.1 + calibrate = True + + # smooth + q_ref, scale_ref, maxs_ref = torch_rms_and_smooth_quant_forward(x, weight, + smooth_scale=smooth_scale, + round_scale=True) + + q, scale, maxs, rms = triton_rms_norm_and_smooth_quant_forward(x, weight, + smooth_scale=smooth_scale, + calibrate=calibrate, + output_rms=True, + round_scale=True) + output_check(q_ref, q, mode="smooth.data") + output_check(scale_ref, scale, mode='smooth.scale') + if calibrate: + output_check(maxs_ref, maxs, mode="smooth.maxs") + + if bench: + benchmark_func(triton_rms_norm_and_smooth_quant_forward, x, weight, + smooth_scale=smooth_scale, + calibrate=True, + round_scale=True, + output_rms=True, + ref_bytes=M * N * 3) + + def test_rmsnorm_and_block_quant(M=4096, N=4096, bench=False): dtype = torch.bfloat16 device = 
'cuda:0' @@ -161,10 +187,9 @@ def test_rmsnorm_and_block_quant(M=4096, N=4096, bench=False): weight = torch.randn(N, dtype=dtype, requires_grad=True, device=device) # blockwise - q_ref, scale_ref, qt_ref, scale_t_ref = torch_rms_and_quant_forward(x, - weight, - smooth_scale=None, - round_scale=True) + q_ref, scale_ref, qt_ref, scale_t_ref = torch_rms_and_block_quant_forward(x, + weight, + round_scale=True) q, scale, rms, q_t, scale_t = triton_rms_norm_and_block_quant_forward(x, weight, round_scale=True, @@ -249,9 +274,14 @@ def test_group_norm_gate_quant(bs=1, length=4096, dim=4096, group_size=4, test_rmsnorm(M=16384, N=2048, bench=False) test_rmsnorm(M=8192, N=4096, bench=False) test_rmsnorm(M=4096, N=8192, bench=False) + test_rmsnorm_and_smooth_quant(M=16384, N=2048, bench=False) + test_rmsnorm_and_smooth_quant(M=8192, N=4096, bench=False) + test_rmsnorm_and_smooth_quant(M=4096, N=8192, bench=False) test_rmsnorm_and_block_quant(M=128, N=2048, bench=False) test_rmsnorm_and_block_quant(M=8192, N=4096, bench=False) test_group_norm_gate_quant(bs=2, length=4096, dim=2048, group_size=4, bench=True) test_group_norm_gate_quant(bs=1, length=4096, dim=4096, group_size=4, bench=True) + + diff --git a/tests/test_silu.py b/tests/test_silu.py index e0e4e24..2375b90 100644 --- a/tests/test_silu.py +++ b/tests/test_silu.py @@ -10,12 +10,18 @@ import torch from linghe.tools.benchmark import benchmark_func -from linghe.utils.silu import ( - triton_batch_weighted_silu_and_block_quant_backward, - triton_batch_weighted_silu_and_block_quant_forward, - triton_silu_and_block_quant_backward, - triton_silu_and_block_quant_forward) -from linghe.tools.util import output_check, torch_group_quant +from linghe.utils.silu import (triton_weighted_silu_forward, + triton_weighted_silu_backward, + triton_batch_weighted_silu_and_smooth_quant_backward, + triton_batch_weighted_silu_and_smooth_quant_forward, + triton_batch_weighted_silu_and_block_quant_backward, + triton_batch_weighted_silu_and_block_quant_forward, + triton_silu_and_smooth_quant_backward, + triton_silu_and_smooth_quant_forward, + triton_silu_and_block_quant_backward, + triton_silu_and_block_quant_forward) +from linghe.tools.util import output_check, torch_smooth_quant, \ + torch_group_quant def torch_silu(x): @@ -40,6 +46,24 @@ def torch_weighted_silu_backward(dy, x, weight): return x.grad, weight.grad +def torch_silu_and_smooth_quant_forward(x, smooth_scale=None, round_scale=True): + M, N = x.shape + x = x.float() + x1, x2 = torch.split(x, N // 2, dim=1) + y = torch.sigmoid(x1) * x1 * x2 + + # smooth + y_q, y_scale, x_maxs = torch_smooth_quant(y, smooth_scale, reverse=False, + round_scale=round_scale) + # y_smooth = y / smooth_scale + # x_maxs = y.abs().float().amax(0) + # y_scale = y_smooth.abs().amax(1) / 448 + # if round_scale: + # y_scale = torch.exp2(torch.ceil(torch.log2(y_scale))) + # y_q = (y_smooth / y_scale[:, None]).to(torch.float8_e4m3fn) + return y_q, y_scale, x_maxs + + def torch_silu_and_block_quant_forward(x, round_scale=True): M, N = x.shape x = x.float() @@ -48,9 +72,26 @@ def torch_silu_and_block_quant_forward(x, round_scale=True): # blockwise y_q, y_scale = torch_group_quant(y, round_scale=round_scale) yt_q, yt_scale = torch_group_quant(y.t(), round_scale=round_scale) - x_maxs = None - return y_q, y_scale, x_maxs, yt_q, yt_scale + return y_q, y_scale, yt_q, yt_scale + + +def torch_silu_and_smooth_quant_backward(grad, x, smooth_scale=None, + transpose_smooth_scale=None, + round_scale=True, reverse=True): + grad = grad.float() + x = 
x.float().detach().clone().requires_grad_() + y = torch_silu(x) + y.backward(gradient=grad) + dx = x.grad + + q, dx_scale, ms = torch_smooth_quant(dx, smooth_scale, reverse=reverse, + round_scale=round_scale) + yt_q, yt_scale, ms = torch_smooth_quant(dx.t().contiguous(), + transpose_smooth_scale, + reverse=reverse, + round_scale=round_scale) + return q, dx_scale, yt_q, yt_scale def torch_silu_and_block_quant_backward(grad, x, round_scale=True): @@ -66,11 +107,47 @@ def torch_silu_and_block_quant_backward(grad, x, round_scale=True): return q, dx_scale, yt_q, yt_scale +def torch_batch_weighted_silu_and_smooth_quant_forward(xs, weight, + counts, + smooth_scales=None, + round_scale=True, + reverse=False): + counts = counts.tolist() + N = xs.shape[1] + if sum(counts) == 0: + device = xs.device + qs = torch.empty((0, N // 2), device=device, dtype=torch.float8_e4m3fn) + scales = torch.empty((0,), device=device, dtype=torch.float32) + maxs = torch.zeros((len(counts), N), device=device, dtype=torch.float32) + return qs, scales, maxs + + xs = xs.float() + weight = weight.float() + smooth_scales = smooth_scales.float() + + qs = [] + scales = [] + maxs = [] + s = 0 + for i, c in enumerate(counts): + x = xs[s:s + c] + y = torch_weighted_silu(x, weight[s:s + c]) + q, scale, ms = torch_smooth_quant(y, smooth_scales[i], reverse=reverse, + round_scale=round_scale) + qs.append(q) + scales.append(scale) + maxs.append(ms) + + s += c + qs = torch.cat(qs, 0) + scales = torch.cat(scales, 0) + maxs = torch.cat(maxs, 0) + return qs, scales, maxs + + def torch_batch_weighted_silu_and_block_quant_forward(xs, weight, counts, - smooth_scales=None, - round_scale=True, - reverse=False): + round_scale=True): counts = counts.tolist() N = xs.shape[1] if sum(counts) == 0: @@ -98,6 +175,7 @@ def torch_batch_weighted_silu_and_block_quant_forward(xs, weight, scales.append(scale.t().contiguous().view(-1)) qts.append(qt.view(-1)) qtscales.append(qtscale.t().contiguous().view(-1)) + s += c qs = torch.cat(qs, 0) scales = torch.cat(scales, 0) @@ -106,12 +184,64 @@ def torch_batch_weighted_silu_and_block_quant_forward(xs, weight, return qs, scales, qts, qtscales +def torch_batch_weighted_silu_and_smooth_quant_backward(grad_output, x, weight, + counts, + smooth_scales=None, + transpose_smooth_scale=None, + round_scale=True, + reverse=False): + if sum(counts) == 0: + device = x.device + N = x.shape[1] + dx_q = torch.empty((0, N), device=device, dtype=torch.float8_e4m3fn) + dx_scale = torch.empty((0,), device=device, dtype=torch.float32) + dw = torch.empty_like(weight) + qts = torch.empty((0,), device=device, dtype=torch.float8_e4m3fn) + qtscales = torch.zeros((N * len(counts),), device=device, + dtype=torch.float32) + return dx_q, dx_scale, dw, qts, qtscales + + grad_output = grad_output.float() + x = x.float() + weight = weight.float() + smooth_scales = smooth_scales.float() + transpose_smooth_scale = transpose_smooth_scale.float() + + dx, dw = torch_weighted_silu_backward(grad_output, x, weight) + qs = [] + scales = [] + qts = [] + qtscales = [] + s = 0 + for i, c in enumerate(counts): + q, scale, dx_max = torch_smooth_quant(dx[s:s + c], smooth_scales[i], + reverse=reverse, + round_scale=round_scale) + dxt = dx[s:s + c].t().contiguous() + dxt_s = transpose_smooth_scale[s:s + c] + padding_size = (c + 31) // 32 * 32 - c + if padding_size > 0: + dxt = torch.nn.functional.pad(dxt, (0, padding_size, 0, 0)) + dxt_s = torch.nn.functional.pad(dxt_s, (0, padding_size)) + qt, t_scale, dx_max = torch_smooth_quant(dxt, dxt_s, + reverse=reverse, 
+ round_scale=round_scale) + + qs.append(q) + scales.append(scale) + qts.append(qt.view(-1)) + qtscales.append(t_scale.view(-1)) + s += c + dx_q = torch.cat(qs, 0) + dx_scale = torch.cat(scales, 0) + qts = torch.cat(qts, 0) + qtscales = torch.cat(qtscales, 0) + return dx_q, dx_scale, dw, qts, qtscales + + def torch_batch_weighted_silu_and_block_quant_backward(grad_output, x, weight, counts, - smooth_scales=None, - transpose_smooth_scale=None, - round_scale=True, - reverse=False): + round_scale=True): if sum(counts) == 0: device = x.device N = x.shape[1] @@ -126,10 +256,6 @@ def torch_batch_weighted_silu_and_block_quant_backward(grad_output, x, weight, grad_output = grad_output.float() x = x.float() weight = weight.float() - if smooth_scales is not None: - smooth_scales = smooth_scales.float() - if transpose_smooth_scale is not None: - transpose_smooth_scale = transpose_smooth_scale.float() dx, dw = torch_weighted_silu_backward(grad_output, x, weight) qs = [] @@ -145,6 +271,7 @@ def torch_batch_weighted_silu_and_block_quant_backward(grad_output, x, weight, scales.append(scale.t().contiguous().view(-1)) qts.append(qt.view(-1)) qtscales.append(qtscale.t().contiguous().view(-1)) + s += c dx_q = torch.cat(qs, 0) dx_scale = torch.cat(scales, 0) @@ -153,14 +280,106 @@ def torch_batch_weighted_silu_and_block_quant_backward(grad_output, x, weight, return dx_q, dx_scale, dw, qts, qtscales -def test_silu_and_block_quant(M=4096, N=4096, bench=False): +def test_weighted_silu(M=4096, N=4096, bench=False): x = torch.randn((M, N), dtype=torch.bfloat16, device='cuda:0') - x = (x * 10).clone().detach().requires_grad_() + x = (x ** 3 // 10).clone().detach().requires_grad_() + weight = torch.randn((M, 1), dtype=torch.bfloat16, device='cuda:0') grad_output = torch.randn((M, N // 2), dtype=torch.bfloat16, device='cuda:0') + ref_y = torch_weighted_silu(x, weight) + y = triton_weighted_silu_forward(x, weight) + output_check(ref_y, y, 'y') + + dx_ref, dw_ref = torch_weighted_silu_backward(grad_output, x, weight) + dx, dw = triton_weighted_silu_backward(grad_output, x, weight) + output_check(dx_ref, dx, 'dx') + output_check(dw_ref, dw, 'dw') + + if bench: + benchmark_func(triton_weighted_silu_forward, x, weight, n_repeat=100, + ref_bytes=M * N * 3) + benchmark_func(triton_weighted_silu_backward, grad_output, x, weight, + n_repeat=100, ref_bytes=M * N * 5) + + +def test_silu_and_smooth_quant(M=4096, N=4096, bench=False): + if True: + x = torch.randn((M, N), dtype=torch.bfloat16, device='cuda:0') + x = (x * 10).clone().detach().requires_grad_() + grad_output = torch.randn((M, N // 2), dtype=torch.bfloat16, + device='cuda:0') + smooth_scale = 1 + torch.rand((N // 2,), dtype=torch.float32, + device='cuda:0') + grad_smooth_scale = 1 + torch.rand((N,), dtype=torch.float32, + device='cuda:0') + transpose_grad_smooth_scale = 1 + torch.rand((M,), dtype=torch.float32, + device='cuda:0') + else: + d = torch.load('/ossfs/workspace/tmp/vis/silu.bin') + x = d['x'].clone().detach().to('cuda:0').requires_grad_() + grad_output = d['g'].to('cuda:0') + grad_smooth_scale = d['smooth_scale'].to('cuda:0') + N = x.shape[-1] + M = x.shape[0] + smooth_scale = 1 + torch.rand((N // 2,), dtype=torch.float32, + device='cuda:0') + + y_q_ref, y_scale_ref, y_maxs_ref = torch_silu_and_smooth_quant_forward(x, + smooth_scale=smooth_scale) + y_q, y_scale, y_maxs = triton_silu_and_smooth_quant_forward(x, + smooth_scale=smooth_scale, + round_scale=True, + calibrate=True) + output_check(y_q_ref.float(), y_q.float(), 'smooth.y_q') + 
output_check(y_scale_ref, y_scale, 'smooth.y_scale') + output_check(y_maxs_ref, y_maxs, 'smooth.y_max') + + dx_q_ref, dx_scale_ref, dxt_q_ref, dxt_scale_ref = torch_silu_and_smooth_quant_backward( + grad_output, x, + smooth_scale=grad_smooth_scale, + transpose_smooth_scale=transpose_grad_smooth_scale, + reverse=True, + round_scale=True) + dx_q, dx_scale, dxt_q, dxt_scale = triton_silu_and_smooth_quant_backward( + grad_output, x, + smooth_scale=grad_smooth_scale, + transpose_smooth_scale=transpose_grad_smooth_scale, + reverse=True, + round_scale=True) + + output_check(dx_q_ref.float(), dx_q.float(), 'smooth.dx_data') + output_check(dx_scale_ref, dx_scale, 'smooth.dx_scale') + output_check(dxt_q_ref.float(), dxt_q.float(), 'smooth.dxt_data') + output_check(dxt_scale_ref, dxt_scale, 'smooth.dxt_scale') + + if bench: + benchmark_func(torch_silu_and_smooth_quant_forward, x, + smooth_scale=smooth_scale, + n_repeat=100, ref_bytes=M * N * 2.5) + benchmark_func(triton_silu_and_smooth_quant_forward, x, + smooth_scale=smooth_scale, + n_repeat=100, ref_bytes=M * N * 2.5) + benchmark_func(triton_silu_and_smooth_quant_backward, grad_output, x, + smooth_scale=grad_smooth_scale, + transpose_smooth_scale=transpose_grad_smooth_scale, + n_repeat=100, ref_bytes=M * N * 5) + - y_q_ref, y_scale_ref, _, yt_q_ref, yt_scale_ref = torch_silu_and_block_quant_forward( - x) +def test_silu_and_block_quant(M=4096, N=4096, bench=False): + if True: + x = torch.randn((M, N), dtype=torch.bfloat16, device='cuda:0') + x = (x * 10).clone().detach().requires_grad_() + grad_output = torch.randn((M, N // 2), dtype=torch.bfloat16, + device='cuda:0') + else: + d = torch.load('/ossfs/workspace/tmp/vis/silu.bin') + x = d['x'].clone().detach().to('cuda:0').requires_grad_() + grad_output = d['g'].to('cuda:0') + N = x.shape[-1] + M = x.shape[0] + + y_q_ref, y_scale_ref, yt_q_ref, yt_scale_ref = torch_silu_and_block_quant_forward( + x, round_scale=True) y_q, y_scale, yt_q, yt_scale = triton_silu_and_block_quant_forward(x, round_scale=True, output_mode=2) @@ -199,16 +418,107 @@ def test_silu_and_block_quant(M=4096, N=4096, bench=False): n_repeat=100, ref_bytes=M * N * 5) +def test_triton_batch_weighted_silu_and_smooth_quant(M=1024, N=4096, + n_experts=32, + bench=False): + if True: + count_list = [random.randint(M // 2, M // 2 * 3) // 16 * 16 for _ in + range(n_experts)] + counts = torch.tensor(count_list, device='cuda:0', dtype=torch.int32) + bs = sum(count_list) + + x = torch.randn((bs, N), dtype=torch.bfloat16, device='cuda:0') ** 3 / 4 + weight = torch.randn((bs, 1), dtype=torch.float32, device='cuda:0') + smooth_scales = 1 + torch.rand((n_experts, N // 2), dtype=torch.float32, + device='cuda:0') * 10 + else: + d = torch.load('/ossfs/workspace/Megatron-LM/silu.bin') + counts = d['counts'].cuda() + x = d['x'].cuda() + weight = d['weight'].cuda() + smooth_scales = d['smooth_scale'].cuda() + bs = sum(counts.tolist()) + N = x.shape[-1] + n_experts = counts.shape[0] + + grad_output = torch.randn((bs, N // 2), dtype=torch.bfloat16, + device='cuda:0') ** 3 + grad_smooth_scales = 1 + torch.rand((n_experts, N), dtype=torch.float32, + device='cuda:0') * 10 + transpose_grad_smooth_scales = 1 + torch.rand((bs,), dtype=torch.float32, + device='cuda:0') * 10 + round_scale = True + + x_q_ref, x_scale_ref, x_max_ref = torch_batch_weighted_silu_and_smooth_quant_forward( + x, + weight, + counts, + smooth_scales=smooth_scales, + round_scale=round_scale, + reverse=False) + x_q, x_scale, maxs = triton_batch_weighted_silu_and_smooth_quant_forward(x, + 
weight, + counts, + smooth_scale=smooth_scales, + round_scale=round_scale, + reverse=False) + output_check(x_q_ref.float(), x_q.float(), 'smooth.data') + output_check(x_scale_ref, x_scale, 'smooth.scale') + + dx_ref, dx_scale_ref, dw_ref, dxt_ref, dxt_scale_ref = torch_batch_weighted_silu_and_smooth_quant_backward( + grad_output, x, weight, count_list, + smooth_scales=grad_smooth_scales, + transpose_smooth_scale=transpose_grad_smooth_scales, + round_scale=round_scale, reverse=False) + dx, dx_scale, dw, dxt, dxt_scale = triton_batch_weighted_silu_and_smooth_quant_backward( + grad_output, x, weight, counts, + smooth_scale=grad_smooth_scales, + transpose_smooth_scale=transpose_grad_smooth_scales, + splits=count_list, + round_scale=round_scale, + reverse=False) + output_check(dx_ref.float(), dx.float(), 'smooth.dx') + output_check(dx_scale_ref, dx_scale, 'smooth.dx_scale') + output_check(dw_ref, dw, 'smooth.dw') + output_check(dxt_ref.float(), dxt.float(), 'smooth.dxt') + output_check(dxt_scale_ref, dxt_scale.view(-1), 'smooth.dxt_scale') + + if bench: + ref_time = None + benchmark_func(triton_batch_weighted_silu_and_smooth_quant_forward, x, + weight, + counts, smooth_scale=smooth_scales, round_scale=True, + ref_bytes=n_experts * M * N * 2.5, ref_time=ref_time) + benchmark_func(triton_batch_weighted_silu_and_smooth_quant_backward, + grad_output, x, weight, counts, + smooth_scale=smooth_scales, + transpose_smooth_scale=transpose_grad_smooth_scales, + splits=count_list, + round_scale=True, + ref_bytes=n_experts * M * N * 4, ref_time=ref_time) + + def test_triton_batch_weighted_silu_and_block_quant(M=1024, N=4096, n_experts=32, bench=False): - count_list = [random.randint(M // 2, M // 2 * 3) // 16 * 16 for _ in - range(n_experts)] - counts = torch.tensor(count_list, device='cuda:0', dtype=torch.int32) - bs = sum(count_list) - - x = torch.randn((bs, N), dtype=torch.bfloat16, device='cuda:0') ** 3 / 10 - weight = torch.randn((bs, 1), dtype=torch.float32, device='cuda:0') + if True: + count_list = [random.randint(M // 2, M // 2 * 3) // 16 * 16 for _ in + range(n_experts)] + counts = torch.tensor(count_list, device='cuda:0', dtype=torch.int32) + bs = sum(count_list) + + x = torch.randn((bs, N), dtype=torch.bfloat16, + device='cuda:0') ** 3 / 10 + weight = torch.randn((bs, 1), dtype=torch.float32, device='cuda:0') + else: + d = torch.load('/ossfs/workspace/Megatron-LM/silu.bin') + counts = d['counts'].cuda() + x = d['x'].cuda() + weight = d['weight'].cuda() + smooth_scales = d['smooth_scale'].cuda() + bs = sum(counts.tolist()) + N = x.shape[-1] + n_experts = counts.shape[0] grad_output = torch.randn((bs, N // 2), dtype=torch.bfloat16, device='cuda:0') ** 3 @@ -220,10 +530,13 @@ def test_triton_batch_weighted_silu_and_block_quant(M=1024, N=4096, counts, round_scale=round_scale) x_q, x_scale, xt_q, xt_scale = triton_batch_weighted_silu_and_block_quant_forward( - x, weight, + x, + weight, counts, + count_list, round_scale=round_scale, - splits=count_list) + output_mode=2) + output_check(x_q_ref.float(), x_q.float(), 'block.q') output_check(x_scale_ref, x_scale, 'block.scale') output_check(xt_q_ref.float(), xt_q.float(), 'block.qt') @@ -248,11 +561,6 @@ def test_triton_batch_weighted_silu_and_block_quant(M=1024, N=4096, counts, round_scale=True, splits=count_list, output_mode=0, n_repeat=100, ref_bytes=n_experts * M * N * 2.5, ref_time=ref_time) - benchmark_func(triton_batch_weighted_silu_and_block_quant_forward, x, - weight, - counts, round_scale=True, splits=count_list, - output_mode=1, n_repeat=100, - 
ref_bytes=n_experts * M * N * 2.5, ref_time=ref_time) benchmark_func(triton_batch_weighted_silu_and_block_quant_forward, x, weight, counts, round_scale=True, splits=count_list, @@ -265,9 +573,20 @@ def test_triton_batch_weighted_silu_and_block_quant(M=1024, N=4096, if __name__ == '__main__': + test_weighted_silu(M=16384, N=1024, bench=True) + + test_silu_and_smooth_quant(M=16384, N=1024, bench=False) + test_silu_and_smooth_quant(M=8192, N=2048, bench=False) + test_silu_and_smooth_quant(M=4096, N=10240, bench=False) + test_silu_and_smooth_quant(M=4096, N=5120, bench=False) + test_silu_and_block_quant(M=16384, N=1024, bench=True) + test_triton_batch_weighted_silu_and_smooth_quant(M=2048, N=2048, + n_experts=32, bench=False) + test_triton_batch_weighted_silu_and_smooth_quant(M=800, N=2048, n_experts=32, bench=False) + test_triton_batch_weighted_silu_and_smooth_quant(M=0, N=2048, n_experts=32, bench=False) + test_triton_batch_weighted_silu_and_block_quant(M=4096, N=2048, n_experts=32, bench=True) - test_triton_batch_weighted_silu_and_block_quant(M=1008, N=2048, - n_experts=32, bench=False) + test_triton_batch_weighted_silu_and_block_quant(M=1008, N=2048, n_experts=32, bench=False) diff --git a/tests/test_smooth_quant.py b/tests/test_smooth_quant.py new file mode 100644 index 0000000..4ddc166 --- /dev/null +++ b/tests/test_smooth_quant.py @@ -0,0 +1,328 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) Ant Financial Service Group and its affiliates. +""" + +import torch + +from linghe.quant.smooth import (triton_batch_smooth_quant, + triton_subrow_smooth_quant, + triton_transpose_rescale_smooth_quant, + triton_smooth_quant, + triton_transpose_smooth_quant) +from linghe.tools.benchmark import benchmark_func +from linghe.tools.util import (output_check, + torch_make_indices, + torch_smooth_quant, + round_up) + + +def torch_split_smooth_quant(x_split, smooth_scales, round_scale=False): + x_qs = [] + x_scales = [] + x_maxs = [] + for i, x_ in enumerate(x_split): + x_maxs.append(x_.abs().amax(0)) + x_smooth = x_ / smooth_scales[i] + x_scale_ = x_smooth.float().abs().amax(1) / 448 + if round_scale: + x_scale_ = torch.exp2(torch.ceil(torch.log2(x_scale_))) + x_q_ = (x_smooth / x_scale_[:, None]).to(torch.float8_e4m3fn) + x_qs.append(x_q_) + x_scales.append(x_scale_) + x_maxs = torch.stack(x_maxs, 0) + return x_qs, x_scales, x_maxs + + +def torch_subrow_smooth_quant(x, smooth_scale, x_q, x_scale, subrow_scales, + offset, size, + reverse=False, round_scale=False): + limit = 448 * torch.ones((1,), dtype=smooth_scale.dtype, + device=smooth_scale.device) + # subrow_scales is saved as 448/max + + M, N = x_q.shape + if offset % N > 0: + si = offset % N + k = N - si + x_slice = x.view(-1)[0:k] + smooth_scale_slice = smooth_scale[si: N] + if not reverse: + smooth_scale_slice = 1 / smooth_scale_slice + x_smooth = x_slice * smooth_scale_slice + + scale = subrow_scales[0:1] + if round_scale: + scale = torch.exp2(torch.floor(torch.log2(scale))) + + x_q_slice = torch.minimum(torch.maximum(x_smooth / scale, -limit), + limit).to(torch.float8_e4m3fn) + x_q.view(-1)[offset:offset + k] = x_q_slice + + if (offset + size) % N > 0: + k = (offset + size) % N + x_slice = x.view(-1)[-k:] + smooth_scale_slice = smooth_scale[0: k] + if not reverse: + smooth_scale_slice = 1 / smooth_scale_slice + x_smooth = x_slice * smooth_scale_slice + scale = subrow_scales[1:2] + if round_scale: + scale = torch.exp2(torch.floor(torch.log2(scale))) + x_q_slice = torch.minimum(torch.maximum(x_smooth / scale, -limit), + 
limit).to(torch.float8_e4m3fn) + x_q.view(-1)[(offset + size - k):(offset + size)] = x_q_slice + x_scale[(offset + size) // N] = scale + + +def torch_rescale_quant(y_q, org_smooth_scale, y_scale, transpose_smooth_scale, + reverse=True, round_scale=True): + assert reverse + y = y_q.float() / org_smooth_scale * y_scale[:, None] + y_q, y_scale, _ = torch_smooth_quant(y.t(), transpose_smooth_scale, + reverse=True, round_scale=round_scale) + return y_q, y_scale + + +def triton_split_smooth_quant(x_split, smooth_scales): + x_qs = [] + x_scales = [] + for i, x_ in enumerate(x_split): + x_q_, x_scale_, _ = triton_smooth_quant(x_, smooth_scales[i]) + x_qs.append(x_q_) + x_scales.append(x_scale_) + return x_qs, x_scales + + +def test_triton_smooth_quant(M=4096, N=4096, bench=False): + device = 'cuda:0' + x = torch.randn((M, N), dtype=torch.bfloat16, device=device) + smooth_scale = torch.randn((N,), device=device, dtype=torch.float32).abs() + x_q_ref, scales_ref, x_maxs_ref = torch_smooth_quant(x, smooth_scale, + reverse=False, + round_scale=True) + + x_q, x_scale, x_maxs = triton_smooth_quant(x, smooth_scale, + reverse=False, + round_scale=True, + calibrate=True) + output_check(x_q_ref.float(), x_q.float(), + 'triton_smooth_quant.data') + output_check(scales_ref, x_scale, 'triton_smooth_quant.scale') + output_check(x_maxs_ref, x_maxs, 'triton_smooth_quant.x_maxs') + + if bench: + benchmark_func(triton_smooth_quant, x, + smooth_scale, + reverse=False, + round_scale=True, + calibrate=False, + ref_bytes=M * N * 3) + + +def test_triton_subrow_smooth_quant(M=4096, N=5120, offset=4096, + size=16384): + device = 'cuda:0' + x = torch.randn((size,), dtype=torch.float32, device=device) + x_q = torch.zeros((M, N), dtype=torch.bfloat16, device=device).to( + torch.float8_e4m3fn) + x_scale = torch.zeros((M,), dtype=torch.float32, device=device).abs() + smooth_scale = torch.randn((N,), device=device, + dtype=torch.float32).abs() + 1 + subrow_scales = torch.randn((2,), device=device, + dtype=torch.float32).abs() + 1 + + x_ref = x.clone() + x_q_ref = x_q.clone() + x_scale_ref = x_scale.clone() + subrow_scales_ref = subrow_scales.clone() + torch_subrow_smooth_quant(x_ref, smooth_scale, x_q_ref, x_scale_ref, + subrow_scales_ref, offset, size, + reverse=False, round_scale=False) + + triton_subrow_smooth_quant(x, smooth_scale, x_q, x_scale, + subrow_scales, offset, size, + reverse=False, round_scale=False) + + output_check(x_q_ref.float(), x_q.float(), 'subrow.data') + output_check(x_scale_ref, x_scale, 'subrow.scale') + + if offset % N > 0: + k = N - offset % N + output_check(x_q_ref.float().view(-1)[offset:offset + k], + x_q.float().view(-1)[offset:offset + k], + 'subrow.data.tail') + + if (offset + size) % N > 0: + k = (offset + size) % N + output_check(x_q_ref.float().view(-1)[offset + size - k:offset + size], + x_q.float().view(-1)[offset + size - k:offset + size], + 'subrow.data.head') + row_id = (offset + size) // N + output_check(x_scale_ref[row_id], x_scale[row_id], 'subrow.scale.slice') + + +def test_triton_transpose_smooth_quant(M=4096, N=4096, bench=False): + device = 'cuda:0' + P = round_up(M, b=32) + y = torch.randn((M, N), dtype=torch.bfloat16, device=device) ** 3 * 1e-10 + transpose_smooth_scale = torch.randn((M,), device=device, + dtype=torch.float32).abs() * 10 + 1 + yt_q, yt_scale = triton_transpose_smooth_quant(y, + transpose_smooth_scale, + reverse=True, + pad=True, + round_scale=True) + q_ref, scale_ref, maxs_ref = torch_smooth_quant(y.T.contiguous(), + transpose_smooth_scale, + reverse=True, + 
round_scale=True)
+
+    assert yt_q.shape[1] == P
+    if P > M:
+        assert yt_q.float()[:, M:].abs().sum().item() == 0
+    output_check(q_ref, yt_q[:, :M],
+                 'triton_transpose_smooth_quant.data')
+    output_check(scale_ref, yt_scale,
+                 'triton_transpose_smooth_quant.scale')
+
+    if bench:
+        benchmark_func(triton_transpose_smooth_quant, y,
+                       transpose_smooth_scale,
+                       reverse=True,
+                       pad=True,
+                       round_scale=True,
+                       ref_bytes=M * N * 3)
+
+
+def test_triton_transpose_rescale_smooth_quant(M=4096, N=4096,
+                                               round_scale=False):
+    device = 'cuda:0'
+    P = round_up(M, b=32)
+    y = torch.randn((M, N), dtype=torch.bfloat16, device=device) ** 3
+    org_smooth_scale = torch.randn((N,), device=device,
+                                   dtype=torch.float32).abs() * 10 + 1
+    if round_scale:
+        org_smooth_scale = torch.exp2(torch.ceil(torch.log2(org_smooth_scale)))
+    transpose_smooth_scale = torch.randn((M,), device=device,
+                                         dtype=torch.float32).abs() + 0.1
+    if round_scale:
+        transpose_smooth_scale = torch.exp2(
+            torch.ceil(torch.log2(transpose_smooth_scale)))
+
+    y_q, y_scale, y_maxs = triton_smooth_quant(y, org_smooth_scale,
+                                               reverse=True,
+                                               round_scale=round_scale)
+
+    yt_gt, yt_scale_gt, yt_maxs_gt = torch_smooth_quant(y.t(),
+                                                        transpose_smooth_scale,
+                                                        reverse=True,
+                                                        round_scale=round_scale)
+
+    yt_q_ref, yt_scale_ref = torch_rescale_quant(y_q, org_smooth_scale, y_scale,
+                                                 transpose_smooth_scale,
+                                                 reverse=True,
+                                                 round_scale=round_scale)
+
+    yt_q, yt_scale = triton_transpose_rescale_smooth_quant(y_q,
+                                                           org_smooth_scale,
+                                                           y_scale,
+                                                           transpose_smooth_scale,
+                                                           reverse=True,
+                                                           pad=True,
+                                                           round_scale=round_scale)
+
+    if P > M:
+        assert yt_q.shape[1] == P
+        assert yt_q.float()[:, M:].abs().sum().item() == 0
+
+    output_check(yt_q_ref, yt_q[:, :M],
+                 'triton_transpose_rescale_smooth_quant.data')
+    output_check(yt_scale_ref, yt_scale,
+                 'triton_transpose_rescale_smooth_quant.scale')
+
+    # TODO: dequantize and compare with the ground truth (yt_gt) computed above
+    # output_check(yt_gt, yt_q[:, :M],
+    #              'triton_transpose_rescale_smooth_quant.data.gt')
+    # output_check(yt_scale_gt, yt_scale,
+    #              'triton_transpose_rescale_smooth_quant.scale.gt')
+
+
+def test_triton_batch_smooth_quant(M=4096, N=4096, n_experts=32, topk=8,
+                                   round_scale=False, bench=False):
+    device = 'cuda:0'
+
+    smooth_scales = 1 + 10 * torch.rand((n_experts, N), device=device,
+                                        dtype=torch.float32)
+
+    logits = torch.randn((M, n_experts), dtype=torch.float32, device=device)
+    probs, mask_map, token_count_per_expert, indices, row_id_map = torch_make_indices(
+        logits, topk=topk, bias=0.0)
+    token_count_per_expert_list = token_count_per_expert.tolist()
+    x = torch.randn((sum(token_count_per_expert_list), N), dtype=torch.bfloat16,
+                    device=device)
+
+    x_q, x_scale, x_maxs = triton_batch_smooth_quant(x, smooth_scales,
+                                                     token_count_per_expert,
+                                                     reverse=False,
+                                                     round_scale=round_scale,
+                                                     calibrate=True)
+
+    x_split = torch.split(x, token_count_per_expert_list)
+    x_q_ref, x_scale_ref, x_maxs_ref = torch_split_smooth_quant(x_split,
+                                                                smooth_scales)
+    x_q_ref = torch.cat([x.view(torch.uint8) for x in x_q_ref], 0).view(
+        torch.float8_e4m3fn)
+    x_scale_ref = torch.cat(x_scale_ref, 0)
+    output_check(x_q_ref.float(), x_q.float(), 'triton_batch_smooth_quant.data')
+    output_check(x_scale_ref.float(), x_scale.float(),
+                 'triton_batch_smooth_quant.scale')
+    output_check(x_maxs_ref.float(), x_maxs.float(),
+                 'triton_batch_smooth_quant.maxs')
+
+    if bench:
+        n_repeat = 100
+        ref_time = benchmark_func(triton_split_smooth_quant, x_split,
+                                  smooth_scales, n_repeat=n_repeat)
+        benchmark_func(triton_batch_smooth_quant, x, smooth_scales,
token_count_per_expert, reverse=False, + round_scale=round_scale, n_repeat=n_repeat, + ref_time=ref_time) + benchmark_func(triton_batch_smooth_quant, x, smooth_scales, + token_count_per_expert, reverse=False, + round_scale=round_scale, calibrate=True, + n_repeat=n_repeat, ref_time=ref_time) + + +if __name__ == '__main__': + test_triton_smooth_quant(M=16384, N=2048, bench=False) + test_triton_smooth_quant(M=8192, N=4096, bench=False) + test_triton_smooth_quant(M=4096, N=8192, bench=False) + test_triton_smooth_quant(M=8192, N=3072, bench=False) + test_triton_smooth_quant(M=8192, N=6144, bench=False) + test_triton_smooth_quant(M=16384, N=512, bench=False) + test_triton_smooth_quant(M=3457, N=512, bench=False) + + test_triton_subrow_smooth_quant(M=4096, N=5120, offset=5120, + size=2048) + test_triton_subrow_smooth_quant(M=4096, N=5120, offset=4096, + size=5120) + test_triton_subrow_smooth_quant(M=4096, N=5120, offset=5120, + size=5120 * 10 - 1024) + + test_triton_transpose_smooth_quant(M=16384, N=2048, bench=False) + test_triton_transpose_smooth_quant(M=8192, N=4096, bench=False) + test_triton_transpose_smooth_quant(M=4096, N=8192, bench=False) + test_triton_transpose_smooth_quant(M=4096, N=3072, bench=False) + + test_triton_transpose_rescale_smooth_quant(M=4096, N=4096, + round_scale=True) + test_triton_transpose_rescale_smooth_quant(M=3895, N=4096, + round_scale=True) + test_triton_transpose_rescale_smooth_quant(M=4096, N=3072, + round_scale=True) + test_triton_transpose_rescale_smooth_quant(M=395, N=2048, + round_scale=True) + + test_triton_batch_smooth_quant(M=4096, N=4096, n_experts=32, topk=8, + round_scale=False)
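A closing note on the two-pass quantization used by
triton_batch_weighted_silu_and_smooth_quant_backward at the top of this patch:
the transposed gradients are first quantized block by block with local scales,
and _batch_requant_kernel then folds each block's local scale into the
per-expert global scale (the amax of the local scales over blocks). A toy
re-derivation under assumed layouts; the shapes and names below are
illustrative, not taken from the kernels:

    import torch

    n_blocks, rows, cols = 4, 32, 8
    local = torch.rand(n_blocks, cols) + 0.5     # per-block, per-column scales
    global_ = local.amax(0)                      # per-column global scale
    q_local = torch.randn(n_blocks, rows, cols)  # values stored as v / local
    # a value v stored as q_local = v / local becomes, under the global scale,
    # q_global = v / global = q_local * local / global
    rescale = local / global_.clamp_min(1e-30)
    q_global = q_local * rescale[:, None, :]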