diff --git a/build.sh b/build.sh
index b123d1c..a7c77bb 100644
--- a/build.sh
+++ b/build.sh
@@ -4,4 +4,4 @@
 rm -rf linghe.egg-info &&
 python setup.py develop &&
 python setup.py bdist_wheel
-# pdoc --output-dir docs -d google --no-include-undocumented --no-search --no-show-source linghe
\ No newline at end of file
+# pdoc --output-dir docs -d google --no-include-undocumented --no-show-source linghe
\ No newline at end of file
diff --git a/docs/linghe.html b/docs/linghe.html
index a7cef32..27dbb27 100644
--- a/docs/linghe.html
+++ b/docs/linghe.html
@@ -18,6 +18,8 @@ Submodules
@@ -48,5 +50,186 @@
diff --git a/docs/linghe/facade.html b/docs/linghe/facade.html
index e50704a..ad2c300 100644
--- a/docs/linghe/facade.html
+++ b/docs/linghe/facade.html
@@ -23,6 +23,8 @@ linghe
 Submodules
@@ -57,5 +59,186 @@
diff --git a/docs/linghe/facade/add.html b/docs/linghe/facade/add.html
index 57a6c73..185a77c 100644
--- a/docs/linghe/facade/add.html
+++ b/docs/linghe/facade/add.html
@@ -23,6 +23,8 @@ linghe.facade
@@ -76,12 +78,193 @@

 Arguments:
 Returns:
-  return updated x tensor
+  updated x tensor
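The page above documents the in-place add facade. For orientation, a minimal sketch of how such an autograd Function is invoked, using the forward(ctx, x, y) signature recorded in the search index later in this diff; the tensor shapes, device, and the exact in-place semantics are assumptions, not statements from the page:

    import torch
    from linghe.facade.add import InplaceAddFunction

    x = torch.randn(1024, 4096, device="cuda")   # shapes and device are assumptions
    y = torch.randn(1024, 4096, device="cuda")
    out = InplaceAddFunction.apply(x, y)         # custom Functions are invoked via .apply()
    # per the Returns entry above, `out` is the updated x tensor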

diff --git a/docs/linghe/facade/fp32_gemm.html b/docs/linghe/facade/fp32_gemm.html
index b4858e4..86fe2a4 100644
--- a/docs/linghe/facade/fp32_gemm.html
+++ b/docs/linghe/facade/fp32_gemm.html
@@ -23,6 +23,8 @@ linghe.facade
@@ -84,5 +86,186 @@ Returns:
diff --git a/docs/linghe/facade/hadamard_quant_linear.html b/docs/linghe/facade/hadamard_quant_linear.html
index 530195c..a713023 100644
--- a/docs/linghe/facade/hadamard_quant_linear.html
+++ b/docs/linghe/facade/hadamard_quant_linear.html
@@ -23,6 +23,8 @@ linghe.facade
@@ -34,12 +36,6 @@

 API Documentation
   • HadamardQuantLinear
-    • forward
-    • extra_repr
@@ -76,41 +72,7 @@
-  Base class for all neural network modules.
-
-  Your models should also subclass this class.
-
-  Modules can also contain other Modules, allowing them to be nested in
-  a tree structure. You can assign the submodules as regular attributes::
-
-      import torch.nn as nn
-      import torch.nn.functional as F
-
-      class Model(nn.Module):
-          def __init__(self) -> None:
-              super().__init__()
-              self.conv1 = nn.Conv2d(1, 20, 5)
-              self.conv2 = nn.Conv2d(20, 20, 5)
-
-          def forward(self, x):
-              x = F.relu(self.conv1(x))
-              return F.relu(self.conv2(x))
-
-  Submodules assigned in this way will be registered, and will also have their
-  parameters converted when you call to(), etc.
-
-  As per the example above, an __init__() call to the parent class
-  must be made before assignment on the child.
-
-  :ivar training: Boolean represents whether this module is in training or
-      evaluation mode.
-  :vartype training: bool
+  a naive implementation of hadamard transformation and quantization
@@ -123,9 +85,7 @@
   a naive implementation of hadamard transformation and quantization
   Arguments:
     • in_features: in feature number
@@ -133,58 +93,193 @@ Arguments:
     • bias: whether use bias
     • device: weight device
    • dtype: weight dtype
    • impl: implementation of hadamard quantization
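For context, a hedged usage sketch of the class documented above. Only in_features, bias, device, dtype and impl are visible in this extract; the out_features argument and drop-in nn.Linear-style behaviour are assumptions:

    import torch
    from linghe.facade.hadamard_quant_linear import HadamardQuantLinear

    layer = HadamardQuantLinear(in_features=4096, out_features=4096,   # out_features assumed
                                bias=False, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    y = layer(x)   # forward(self, input: torch.Tensor) -> torch.Tensor, per the nav entries removed below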
-  def forward(self, input: torch.Tensor) -> torch.Tensor:
-
-      Define the computation performed at every call.
-
-      Should be overridden by all subclasses.
-
-      Although the recipe for forward pass needs to be defined within
-      this function, one should call the Module instance afterwards
-      instead of this since the former takes care of running the
-      registered hooks while the latter silently ignores them.
-
-  def extra_repr(self) -> str:
-
-      Return the extra representation of the module.
-
-      To print customized extra information, you should re-implement
-      this method in your own modules. Both single-line and multi-line
-      strings are acceptable.
diff --git a/docs/linghe/facade/loss.html b/docs/linghe/facade/loss.html
index 21cda4a..8b6d236 100644
--- a/docs/linghe/facade/loss.html
+++ b/docs/linghe/facade/loss.html
@@ -23,6 +23,8 @@ linghe.facade
@@ -77,12 +79,193 @@

 Arguments:
 Returns:
-  per token loss
+  a tensor of per token loss
diff --git a/docs/linghe/facade/norm.html b/docs/linghe/facade/norm.html
index e858009..9f264bf 100644
--- a/docs/linghe/facade/norm.html
+++ b/docs/linghe/facade/norm.html
@@ -23,6 +23,8 @@ linghe.facade
@@ -118,5 +120,186 @@ Returns:
diff --git a/docs/linghe/facade/rope.html b/docs/linghe/facade/rope.html
index 0d39eca..b9189f9 100644
--- a/docs/linghe/facade/rope.html
+++ b/docs/linghe/facade/rope.html
@@ -23,6 +23,8 @@ linghe.facade
@@ -81,14 +83,197 @@
 Arguments:
 Returns:
   qo: shape [B, S, H, head_dim]
   ko: shape [B, S, h, head_dim]
   vo: shape [B, S, h, head_dim]
diff --git a/docs/linghe/facade/smooth_quant_linear.html b/docs/linghe/facade/smooth_quant_linear.html
index 2752e47..05abcc0 100644
--- a/docs/linghe/facade/smooth_quant_linear.html
+++ b/docs/linghe/facade/smooth_quant_linear.html
@@ -23,22 +23,18 @@ linghe.facade

 API Documentation
 Returns:
   output tensor
@@ -218,5 +224,186 @@ Returns:
diff --git a/docs/linghe/quant.html b/docs/linghe/quant.html
index bd6573d..4264812 100644
--- a/docs/linghe/quant.html
+++ b/docs/linghe/quant.html
@@ -23,6 +23,8 @@ linghe
 Submodules
@@ -54,5 +56,186 @@
diff --git a/docs/linghe/quant/block.html b/docs/linghe/quant/block.html
index a562b18..957cbb2 100644
--- a/docs/linghe/quant/block.html
+++ b/docs/linghe/quant/block.html
@@ -23,6 +23,8 @@ linghe.quant
@@ -77,13 +79,196 @@

 Arguments:
 Returns:
   y: quantized tensor, float8_e4m3fn
   s: quantization scale, float32
diff --git a/docs/linghe/quant/channel.html b/docs/linghe/quant/channel.html
index 80243be..d4f506a 100644
--- a/docs/linghe/quant/channel.html
+++ b/docs/linghe/quant/channel.html
@@ -23,6 +23,8 @@ linghe.quant
@@ -140,13 +142,196 @@
 Arguments:
 Returns:
   x_q: quantized tensor
   x_scale: quantization scale
diff --git a/docs/linghe/quant/group.html b/docs/linghe/quant/group.html
index e0191d6..9d311e9 100644
--- a/docs/linghe/quant/group.html
+++ b/docs/linghe/quant/group.html
@@ -23,6 +23,8 @@ linghe.quant
@@ -77,13 +79,196 @@
 Arguments:
 Returns:
   y: quantized tensor, float8_e4m3fn
   s: quantization scale, float32
diff --git a/docs/linghe/quant/hadamard.html b/docs/linghe/quant/hadamard.html
index 917cf30..d956495 100644
--- a/docs/linghe/quant/hadamard.html
+++ b/docs/linghe/quant/hadamard.html
@@ -23,6 +23,8 @@ linghe.quant
@@ -76,15 +78,198 @@
 Arguments:
 Returns:
   x_q: rowwise quantized tensor of non-transposed x
   x_scale: rowwise quantization scale of non-transposed x
   xt_q: columnwise quantized tensor of transposed x
   xt_scale: columnwise quantization scale of transposed x
diff --git a/docs/linghe/quant/smooth.html b/docs/linghe/quant/smooth.html
index c903c18..5922806 100644
--- a/docs/linghe/quant/smooth.html
+++ b/docs/linghe/quant/smooth.html
@@ -23,6 +23,8 @@ linghe.quant
@@ -52,5 +54,186 @@
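The quantizers documented above all return the same pair: a float8_e4m3fn payload plus a float32 scale. As a point of reference only (this is generic absmax quantization in plain PyTorch, not the library's Triton kernels), a rowwise version can be written as:

    import torch

    def rowwise_fp8_quant_reference(x: torch.Tensor):
        # per-row absmax scaling, so that x ~= y.float() * s.unsqueeze(-1)
        fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3fn
        amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
        s = (amax / fp8_max).float()                            # quantization scale, float32
        y = (x / s).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)  # quantized tensor
        return y, s.squeeze(-1)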

diff --git a/docs/linghe/utils.html b/docs/linghe/utils.html
index e37e3fe..afd4a60 100644
--- a/docs/linghe/utils.html
+++ b/docs/linghe/utils.html
@@ -23,6 +23,8 @@ linghe
 Submodules
@@ -60,5 +62,186 @@
diff --git a/docs/linghe/utils/add.html b/docs/linghe/utils/add.html
index 9c8fe47..e84d1b8 100644
--- a/docs/linghe/utils/add.html
+++ b/docs/linghe/utils/add.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -84,5 +86,186 @@ Returns:
diff --git a/docs/linghe/utils/dot.html b/docs/linghe/utils/dot.html
index 49a7b28..67b62f7 100644
--- a/docs/linghe/utils/dot.html
+++ b/docs/linghe/utils/dot.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -84,5 +86,186 @@ Returns:
diff --git a/docs/linghe/utils/gather.html b/docs/linghe/utils/gather.html
index 8f18e38..a7444c8 100644
--- a/docs/linghe/utils/gather.html
+++ b/docs/linghe/utils/gather.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -315,7 +317,7 @@ Arguments:
-  gather and optional dequant and smooth quant
+  gather (and optional dequant) and smooth quant
 Arguments:
@@ -332,11 +334,199 @@ Arguments:
   • round_scale:
-  Returns:
+  Returns:
+    • output: output tensor
+    • permuted_scale: permuted scale if scale is not None
diff --git a/docs/linghe/utils/loss.html b/docs/linghe/utils/loss.html
index 73e62df..6ca4c4f 100644
--- a/docs/linghe/utils/loss.html
+++ b/docs/linghe/utils/loss.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -117,5 +119,186 @@
 Returns:
diff --git a/docs/linghe/utils/norm.html b/docs/linghe/utils/norm.html
index 63e9fb9..4db407c 100644
--- a/docs/linghe/utils/norm.html
+++ b/docs/linghe/utils/norm.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -118,11 +120,14 @@
 Arguments:
 Returns:
   out: quantization data
   scale: quantization scale
   rms: reciprocal of the root mean square of the input, calculated over the last dimension
   transpose_output: quantization data of transposed gradient
   transpose_scale: quantization scale of transposed gradient
@@ -160,5 +165,186 @@ Returns:
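The rms value documented above is the reciprocal of the root mean square over the last dimension. As a shape and semantics reference (a plain PyTorch sketch, not the fused kernel, and omitting the quantized outputs it also produces; the eps and weight handling are standard RMSNorm assumptions):

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        # rms matches the documented return: 1 / sqrt(mean(x**2, dim=-1) + eps)
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        out = x * rms * weight      # standard RMSNorm scaling
        return out, rms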
diff --git a/docs/linghe/utils/rearange.html b/docs/linghe/utils/rearange.html
index dab027f..ec9cc8e 100644
--- a/docs/linghe/utils/rearange.html
+++ b/docs/linghe/utils/rearange.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -79,13 +81,196 @@
 Arguments:
 Returns:
   y: output tensor
   output_scales: output scales if scales is not None
diff --git a/docs/linghe/utils/reduce.html b/docs/linghe/utils/reduce.html
index d000c2d..0aeb0ae 100644
--- a/docs/linghe/utils/reduce.html
+++ b/docs/linghe/utils/reduce.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -147,5 +149,186 @@ Returns:
diff --git a/docs/linghe/utils/rope.html b/docs/linghe/utils/rope.html
index df394a3..c44d6c5 100644
--- a/docs/linghe/utils/rope.html
+++ b/docs/linghe/utils/rope.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -83,8 +85,10 @@
 Arguments:
 Returns:
   qo:
   ko:
@@ -121,9 +125,11 @@
 Arguments:
 Returns:
   qo: shape [B, S, H, head_dim]
   ko: shape [B, S, h, head_dim]
   vo: shape [B, S, h, head_dim]
@@ -159,14 +165,197 @@
 Arguments:
 Returns:
   dqkv: gradient of qkv
   dqw: gradient of q_norm_weight
   dkw: gradient of k_norm_weight
diff --git a/docs/linghe/utils/scatter.html b/docs/linghe/utils/scatter.html
index 68e5f39..360ddf8 100644
--- a/docs/linghe/utils/scatter.html
+++ b/docs/linghe/utils/scatter.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -113,7 +115,7 @@
 Arguments:
 Returns:
-  outputs
+  output tensor
@@ -142,13 +144,196 @@
 Arguments:
 Returns:
   output: [num_tokens, hidden_size]
   restore_probs: [num_tokens, num_experts]
diff --git a/docs/linghe/utils/silu.html b/docs/linghe/utils/silu.html
index cbb5131..75bc83a 100644
--- a/docs/linghe/utils/silu.html
+++ b/docs/linghe/utils/silu.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -120,8 +122,10 @@
 Arguments:
 Returns:
   dx: gradient of x
   dw: gradient of weight
@@ -153,10 +157,12 @@
 Arguments:
 Returns:
   out: quantized tensor
   scale: quantization scale
   transpose_output: quantized tensor of transposed output
   transpose_scale: quantization scale of transposed output
@@ -185,10 +191,12 @@
 Arguments:
 Returns:
   dx: quantized non-transposed gradient
   dx_scale: scales of quantization non-transposed gradient
   transpose_dx: quantized transposed gradient
   transpose_dx_scale: scales of quantization transposed gradient
@@ -223,10 +231,12 @@
 Arguments:
 Returns:
   out: quantized tensor
   scale: quantization scale
   transpose_output: quantized tensor of transposed output
   transpose_scale: quantization scale of transposed output
@@ -258,16 +268,199 @@
 Arguments:
 Returns:
   dx: quantized non-transposed gradient
   dx_scale: scales of quantization non-transposed gradient
   dw: gradient of weight
   transpose_dx: quantized transposed gradient
   transpose_dx_scale: scales of quantization transposed gradient
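The backward entries above return gradients of a SiLU-based elementwise op. For reference, the SiLU derivative that such backward kernels rely on (a generic identity, independent of the fused weight and quantization handling documented here):

    import torch

    def silu_grad_reference(x: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        s = torch.sigmoid(x)
        return dy * s * (1 + x * (1 - s))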
diff --git a/docs/linghe/utils/transpose.html b/docs/linghe/utils/transpose.html
index b278ac2..df81d84 100644
--- a/docs/linghe/utils/transpose.html
+++ b/docs/linghe/utils/transpose.html
@@ -23,6 +23,8 @@ linghe.utils
@@ -180,5 +182,186 @@ Returns:
diff --git a/docs/search.js b/docs/search.js
index 23741f0..e1491d0 100644
--- a/docs/search.js
+++ b/docs/search.js
@@ -1,6 +1,6 @@
[Machine-generated file: the minified elasticlunr 0.9.5 search library followed by the regenerated
window.pdocSearch index. The module-level docstrings in the index are the Ant Financial Service Group
copyright notice, and the class and method docstrings are largely the inherited torch.autograd.Function,
forward, and backward boilerplate. The entries visible in this extract cover:
  linghe.facade.add: InplaceAddFunction (forward(ctx, x, y), backward(ctx, grad_output))
  linghe.facade.fp32_linear: FusedFp32GEMM (forward(ctx, input, weight), backward(ctx, grad_output))
  linghe.facade.loss: SoftmaxCrossEntropyFunction (forward(ctx, logits, labels, inplace=False));
    GradScalingFunction (forward(ctx, x, coef=0.2))
  linghe.facade.norm: RMSNormFunction (forward(ctx, x, weight, eps=1e-06));
    GroupNormGateFunction (forward(ctx, x, gate, weight, eps=1e-06, group_size=4))
  linghe.facade.rope: QkNormHalfRopeFunction (forward(ctx, qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4, eps=1e-06),
    backward(ctx, grad_q, grad_k, grad_v))
  linghe.facade.transpose: TransposeDim01Function (forward(ctx, x), backward(ctx, grad_output))
  linghe.gemm.fp32_gemm: fp32_gemm_kernel (the extract is truncated at this entry)
The minified library source and the raw index JSON are not reproduced here.]

    \n", "signature": "(\ta_ptr,\tb_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm", "kind": "function", "doc": "

    \n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.scaled_fp32_gemm_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "scaled_fp32_gemm_kernel", "kind": "function", "doc": "

    \n", "signature": "(\ta_ptr,\tb_ptr,\tscale_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_scaled_fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_scaled_fp32_gemm", "kind": "function", "doc": "

    \n", "signature": "(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.fp32_gemm_for_backward_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "fp32_gemm_for_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\ta_ptr,\tb_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tACCUM: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm_for_backward", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm_for_backward", "kind": "function", "doc": "

    \n", "signature": "(\ta: torch.Tensor,\tb: torch.Tensor,\tc: Optional[torch.Tensor] = None,\taccum=False):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.fp32_gemm_for_update_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "fp32_gemm_for_update_kernel", "kind": "function", "doc": "

    \n", "signature": "(\ta_ptr,\tb_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm_for_update", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm_for_update", "kind": "function", "doc": "

    \n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.scaled_fp32_gemm_for_update_kernel", "modulename": "linghe.gemm.fp32_gemm", "qualname": "scaled_fp32_gemm_for_update_kernel", "kind": "function", "doc": "

    \n", "signature": "(\ta_ptr,\tb_ptr,\tscale_ptr,\tc_ptr,\tM,\tN: int,\tK: int,\tBLOCK_SIZE_K: int,\tBLOCK_SIZE_M: int,\tBLOCK_SIZE_N: int):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_scaled_fp32_gemm_for_update", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_scaled_fp32_gemm_for_update", "kind": "function", "doc": "

    \n", "signature": "(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.quant", "modulename": "linghe.quant", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.quant.block", "modulename": "linghe.quant.block", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.quant.block.block", "modulename": "linghe.quant.block.block", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.block.block.block_quant_kernel", "modulename": "linghe.quant.block.block", "qualname": "block_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.block.block.block_quant", "modulename": "linghe.quant.block.block", "qualname": "block_quant", "kind": "function", "doc": "

    \n", "signature": "(x, block_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group", "modulename": "linghe.quant.block.group", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.block.group.group_quant_kernel", "modulename": "linghe.quant.block.group", "qualname": "group_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: int, K: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group.triton_group_quant", "modulename": "linghe.quant.block.group", "qualname": "triton_group_quant", "kind": "function", "doc": "

    \n", "signature": "(x, dtype=torch.float8_e4m3fn, group_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group.persist_group_quant_kernel", "modulename": "linghe.quant.block.group", "qualname": "persist_group_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, y_ptr, s_ptr, N, BLOCK_SIZE: int, B: int, K: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.block.group.triton_persist_group_quant", "modulename": "linghe.quant.block.group", "qualname": "triton_persist_group_quant", "kind": "function", "doc": "

    \n", "signature": "(x, dtype=torch.float8_e4m3fn, group_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel", "modulename": "linghe.quant.channel", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.quant.channel.channel", "modulename": "linghe.quant.channel.channel", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.channel.channel.row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "row_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, q_ptr, s_ptr, M, N, BLOCK_SIZE: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_row_quant", "kind": "function", "doc": "

    \n", "signature": "(x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.deprecated_tokenwise_row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "deprecated_tokenwise_row_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, out_ptr, scale_ptr, M, T: int, N: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_deprecated_tokenwise_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_deprecated_tokenwise_row_quant", "kind": "function", "doc": "

    \n", "signature": "(x, out=None, scale=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.tokenwise_row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "tokenwise_row_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, out_ptr, scale_ptr, N: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_tokenwise_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_tokenwise_row_quant", "kind": "function", "doc": "

    \n", "signature": "(x, out=None, scale=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.transpose_row_quant_kernel", "modulename": "linghe.quant.channel.channel", "qualname": "transpose_row_quant_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, q_ptr, s_ptr, M, N, H: int, W: int, ROUND: int):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_transpose_row_quant", "modulename": "linghe.quant.channel.channel", "qualname": "triton_transpose_row_quant", "kind": "function", "doc": "

    \n", "signature": "(x, side=0, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_channel_quant_nt", "modulename": "linghe.quant.channel.channel", "qualname": "triton_channel_quant_nt", "kind": "function", "doc": "

    \n", "signature": "(x, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_channel_quant_nn", "modulename": "linghe.quant.channel.channel", "qualname": "triton_channel_quant_nn", "kind": "function", "doc": "

    \n", "signature": "(y, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.triton_channel_quant_tn", "modulename": "linghe.quant.channel.channel", "qualname": "triton_channel_quant_tn", "kind": "function", "doc": "

    \n", "signature": "(y, x):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.channel_quant_forward", "modulename": "linghe.quant.channel.channel", "qualname": "channel_quant_forward", "kind": "function", "doc": "

    \n", "signature": "(x, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.channel_quant_backward", "modulename": "linghe.quant.channel.channel", "qualname": "channel_quant_backward", "kind": "function", "doc": "

    \n", "signature": "(y, w):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.channel_quant_update", "modulename": "linghe.quant.channel.channel", "qualname": "channel_quant_update", "kind": "function", "doc": "

    \n", "signature": "(y, x):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.channel.fp8_channel_f_and_b", "modulename": "linghe.quant.channel.channel", "qualname": "fp8_channel_f_and_b", "kind": "function", "doc": "

    \n", "signature": "(x, w, y):", "funcdef": "def"}, {"fullname": "linghe.utils", "modulename": "linghe.utils", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.utils.add", "modulename": "linghe.utils.add", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.add.inplace_add_kernel", "modulename": "linghe.utils.add", "qualname": "inplace_add_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, y_ptr, M, N, H: int, W: int, EVEN: int, ACCUM: int):", "funcdef": "def"}, {"fullname": "linghe.utils.add.triton_inplace_add", "modulename": "linghe.utils.add", "qualname": "triton_inplace_add", "kind": "function", "doc": "

in-place add y to x\nArgs:\n    x: Tensor\n    y: Tensor\n    accum: whether to accumulate y into x

    \n\n

    Returns: x += y if accum=True else x.copy_(y)

    \n", "signature": "(x: torch.Tensor, y: torch.Tensor, accum: bool = True):", "funcdef": "def"}, {"fullname": "linghe.utils.dot", "modulename": "linghe.utils.dot", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.dot.dot_kernel", "modulename": "linghe.utils.dot", "qualname": "dot_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, y_ptr, sum_ptr, M, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.triton_dot", "modulename": "linghe.utils.dot", "qualname": "triton_dot", "kind": "function", "doc": "

    \n", "signature": "(x, y):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.mix_precise_dot_kernel", "modulename": "linghe.utils.dot", "qualname": "mix_precise_dot_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tq_ptr,\tsum_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tN,\tH: int,\tW: int):", "funcdef": "def"}, {"fullname": "linghe.utils.dot.triton_mix_precise_dot", "modulename": "linghe.utils.dot", "qualname": "triton_mix_precise_dot", "kind": "function", "doc": "

    \n", "signature": "(x, q, smooth_scale, quant_scale, reverse=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather", "modulename": "linghe.utils.gather", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.gather.block_count_kernel", "modulename": "linghe.utils.gather", "qualname": "block_count_kernel", "kind": "function", "doc": "

    \n", "signature": "(map_ptr, count_ptr, M, B, T: int, b: int, E: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.make_row_id_map_kernel", "modulename": "linghe.utils.gather", "qualname": "make_row_id_map_kernel", "kind": "function", "doc": "

    \n", "signature": "(map_ptr, count_ptr, output_ptr, M, B, P, T: int, b: int, E: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map", "kind": "function", "doc": "

    \n", "signature": "(routing_map: torch.Tensor, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.make_row_id_map_and_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "make_row_id_map_and_indices_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tmap_ptr,\tcount_ptr,\trow_map_ptr,\trow_indices_ptr,\tM,\tB,\tP,\tT: int,\tb: int,\tE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map_and_indices", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map_and_indices", "kind": "function", "doc": "

    \n", "signature": "(routing_map: torch.Tensor, num_out_tokens: int, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.index_select_kernel", "modulename": "linghe.utils.gather", "qualname": "index_select_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tout_ptr,\tscale_ptr,\tscale_out_ptr,\tindex_ptr,\tM,\tT,\tN: int,\tSCALE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_index_select", "modulename": "linghe.utils.gather", "qualname": "triton_index_select", "kind": "function", "doc": "

    \n", "signature": "(x, indices, scale=None, out=None, scale_out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "permute_with_mask_map_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tdata_ptr,\tscale_ptr,\tprobs_ptr,\tmask_map_ptr,\toutput_data_ptr,\toutput_scale_ptr,\toutput_probs_ptr,\tnum_experts: int,\tN: int,\ths: int,\tSCALE: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.fill_padded_token_with_zero_kernel", "modulename": "linghe.utils.gather", "qualname": "fill_padded_token_with_zero_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tdata_ptr,\tscale_ptr,\tprobs_ptr,\tmax_indices_ptr,\ttoken_per_expert_ptr,\tN: int,\ths: int,\tSCALE: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_permute_with_mask_map", "kind": "function", "doc": "

    \n", "signature": "(\tinp: torch.Tensor,\tscale: torch.Tensor,\tprobs: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_out_tokens: int,\tcontiguous: bool = True,\ttokens_per_expert: Optional[torch.Tensor] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.batch_smooth_transpose_smooth_permute_kernel", "modulename": "linghe.utils.gather", "qualname": "batch_smooth_transpose_smooth_permute_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tscale_ptr,\toss_ptr,\tss_ptr,\tindex_ptr,\tcount_ptr,\taccum_ptr,\tq_ptr,\tqs_ptr,\tN: int,\tE: int,\tH: int,\tW: int,\tSMOOTHED: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_batch_transpose_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_batch_transpose_smooth_permute_with_indices", "kind": "function", "doc": "

    \n", "signature": "(\tx,\tscale,\torg_smooth_scale,\tsmooth_scales,\tindices,\ttoken_count_per_expert,\tsplits,\tx_q=None,\tx_scale=None,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_weighted_permute_with_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_weighted_permute_with_indices_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrads_ptr,\ttokens_ptr,\tq_ptr,\tss_ptr,\tqs_ptr,\tcount_ptr,\taccum_ptr,\tindex_ptr,\tsum_ptr,\tM,\tN: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_weighted_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_weighted_permute_with_indices", "kind": "function", "doc": "

    \n", "signature": "(\tgrads,\ttokens,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\tx_sum=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_permute_with_indices_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_permute_with_indices_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrads_data_ptr,\tgrads_scale_ptr,\tq_ptr,\tss_ptr,\tqs_ptr,\tcount_ptr,\taccum_ptr,\tindex_ptr,\tN: int,\ths: int,\tREVERSE: int,\tROUND: int,\tGROUP: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_indices", "kind": "function", "doc": "

    \n", "signature": "(\tgrad_data,\tgrad_scale,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.smooth_permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "smooth_permute_with_mask_map_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrads_data_ptr,\tquant_data_ptr,\tmask_map_ptr,\tgrads_scale_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tT,\tN: int,\ths: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_mask_map", "kind": "function", "doc": "

    \n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tscale: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.deprecated_smooth_permute_with_mask_map_kernel", "modulename": "linghe.utils.gather", "qualname": "deprecated_smooth_permute_with_mask_map_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrads_data_ptr,\tquant_data_ptr,\tmask_map_ptr,\tsmooth_scale_ptr,\tquant_scale_ptr,\tM,\tT,\tN: int,\tREVERSE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_deprecated_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_deprecated_smooth_permute_with_mask_map", "kind": "function", "doc": "

    \n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.loss", "modulename": "linghe.utils.loss", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.loss.softmax_cross_entropy_forward_kernel", "modulename": "linghe.utils.loss", "qualname": "softmax_cross_entropy_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tlogit_ptr,\tlabel_ptr,\tloss_ptr,\tsum_exp_ptr,\tmax_logit_ptr,\tN,\tB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_forward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_forward", "kind": "function", "doc": "

    \n", "signature": "(logits, labels):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.softmax_cross_entropy_backward_kernel", "modulename": "linghe.utils.loss", "qualname": "softmax_cross_entropy_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tlogit_ptr,\tlabel_ptr,\tsum_exp_ptr,\tmax_logit_ptr,\tinput_grad_ptr,\toutput_grad_ptr,\tN,\tB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_backward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_backward", "kind": "function", "doc": "

    \n", "signature": "(logits, labels, sum_exp, max_logit, input_grad, output_grad=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm", "modulename": "linghe.utils.norm", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.utils.norm.rms_norm_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, weight_ptr, out_ptr, eps, M, T, N: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_forward", "kind": "function", "doc": "

    \n", "signature": "(x, weight, eps=1e-06, out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_backward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrad_output_ptr,\tx_ptr,\tw_ptr,\tdx_ptr,\tdw_ptr,\teps,\tM,\tT,\tN: int,\tW: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_backward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_backward", "kind": "function", "doc": "

    \n", "signature": "(grad_output, x, w, eps=1e-06):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\trms_ptr,\teps,\tM,\tT: int,\tN: int,\tnb: int,\tW: int,\tH: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_n_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_n_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\trms_ptr,\teps,\tM: int,\tT: int,\tN: int,\tnb: int,\tW: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.rms_norm_and_block_quant_forward_t_kernel", "modulename": "linghe.utils.norm", "qualname": "rms_norm_and_block_quant_forward_t_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tweight_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\trms_ptr,\tM,\tN,\tW: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_and_block_quant_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_and_block_quant_forward", "kind": "function", "doc": "

Fused RMSNorm forward and block quantization.\nArgs:\n    x: input tensor, shape [M, N]\n    weight: RMSNorm weight, shape [N]\n    eps: epsilon value for L2 normalization.\n    out: optional output buffer for the quantized data\n    scale: optional output buffer for the quantization scale\n    rms: optional output buffer for the rms\n    round_scale: whether to force power-of-2 scales\n    output_mode: one of {0, 1, 2}.\n        0: output only the non-transposed tensor\n        1: output only the transposed tensor\n        2: return both\nReturns:\n    out: quantized data\n    scale: quantization scale\n    rms: reciprocal of the root mean square of the input, computed over the last dimension\n    transpose_output: quantized data of the transposed output\n    transpose_scale: quantization scale of the transposed output

    \n", "signature": "(\tx: torch.Tensor,\tweight: torch.Tensor,\teps: float = 1e-06,\tout: Optional[torch.Tensor] = None,\tscale: Optional[torch.Tensor] = None,\trms: Optional[torch.Tensor] = None,\tround_scale: bool = False,\toutput_mode: int = 2):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.group_norm_gate_forward_kernel", "modulename": "linghe.utils.norm", "qualname": "group_norm_gate_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tgate_ptr,\tweight_ptr,\tout_ptr,\teps,\tbs,\tlength,\tDIM: int,\tD: int,\tGROUP_SIZE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_forward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_forward", "kind": "function", "doc": "

group RMS norm and gating for linear attention\nArgs:\n    x: attention output tensor\n    gate: gate tensor\n    weight: RMS norm weight\n    eps: epsilon for RMS\n    group_size: group size of the group RMS norm

    \n\n

    Returns:

    \n", "signature": "(x: torch.Tensor, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.group_rms_gate_backward_kernel", "modulename": "linghe.utils.norm", "qualname": "group_rms_gate_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrad_output_ptr,\tx_ptr,\tgate_ptr,\tw_ptr,\tdx_ptr,\tdg_ptr,\tdw_ptr,\teps,\tbs,\tlength,\tDIM: int,\tD: int,\tGROUP_SIZE: int,\tT: int):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_backward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_backward", "kind": "function", "doc": "

    \n", "signature": "(grad_output, x, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange", "modulename": "linghe.utils.rearange", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.rearange.split_and_cat_kernel", "modulename": "linghe.utils.rearange", "qualname": "split_and_cat_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\ty_ptr,\tscale_ptr,\tscale_output_ptr,\tcount_ptr,\taccum_ptr,\trev_accum_ptr,\tindex_ptr,\tM,\tN: int,\tSCALE: int,\tK: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange.triton_split_and_cat", "modulename": "linghe.utils.rearange", "qualname": "triton_split_and_cat", "kind": "function", "doc": "

    \n", "signature": "(x, counts, indices, scales=None):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce", "modulename": "linghe.utils.reduce", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.reduce.abs_max_kernel", "modulename": "linghe.utils.reduce", "qualname": "abs_max_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tscale_ptr,\tsmooth_scale_ptr,\toutput_ptr,\tmin_value,\tM,\tN,\tH: int,\tW: int,\tEVEN: int,\tQUANTIZED: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_abs_max", "modulename": "linghe.utils.reduce", "qualname": "triton_abs_max", "kind": "function", "doc": "

    \n", "signature": "(x, scale=None, smooth_scale=None, min_value=1e-30, axis=0):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.batch_count_zero_kernel", "modulename": "linghe.utils.reduce", "qualname": "batch_count_zero_kernel", "kind": "function", "doc": "

    \n", "signature": "(input_ptrs, size_ptr, count_ptr, B: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_count_zero", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_count_zero", "kind": "function", "doc": "

    \n", "signature": "(xs):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.batch_sum_with_ord_kernel", "modulename": "linghe.utils.reduce", "qualname": "batch_sum_with_ord_kernel", "kind": "function", "doc": "

    \n", "signature": "(input_ptrs, size_ptr, count_ptr, B: int, ORD: int):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_sum_with_ord", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_sum_with_ord", "kind": "function", "doc": "

    \n", "signature": "(xs, ord=2):", "funcdef": "def"}, {"fullname": "linghe.utils.rope", "modulename": "linghe.utils.rope", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.rope.half_rope_forward_kernel", "modulename": "linghe.utils.rope", "qualname": "half_rope_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tq_ptr,\tk_ptr,\tfreqs_ptr,\tqo_ptr,\tko_ptr,\tB,\tq_stride,\tk_stride,\tH: int,\th: int,\tD: int,\td: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_forward", "kind": "function", "doc": "

    \n", "signature": "(q, k, freqs):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.half_rope_backward_kernel", "modulename": "linghe.utils.rope", "qualname": "half_rope_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(q_ptr, k_ptr, freqs_ptr, B, H: int, h: int, D: int, d: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_backward", "kind": "function", "doc": "

    \n", "signature": "(q_grad, k_grad, freqs, inplace=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.qk_norm_and_half_rope_forward_kernel", "modulename": "linghe.utils.rope", "qualname": "qk_norm_and_half_rope_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tqkv_ptr,\tq_norm_weight_ptr,\tk_norm_weight_ptr,\tfreqs_ptr,\tqo_ptr,\tko_ptr,\tvo_ptr,\tB,\tstride,\teps,\tH: int,\th: int,\tD: int,\td: int,\tinterleave: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_forward", "kind": "function", "doc": "

    \n", "signature": "(\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\tH=32,\th=4,\teps=1e-06,\tinterleave=True,\ttranspose=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.qk_norm_and_half_rope_backward_kernel", "modulename": "linghe.utils.rope", "qualname": "qk_norm_and_half_rope_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgq_ptr,\tgk_ptr,\tgv_ptr,\tqkv_ptr,\tq_norm_weight_ptr,\tk_norm_weight_ptr,\tfreqs_ptr,\tdqkv_ptr,\tdqw_ptr,\tdkw_ptr,\tB,\tstride,\teps,\tH: int,\th: int,\tD: int,\td: int,\tinterleave: int):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_backward", "kind": "function", "doc": "

    \n", "signature": "(\tgq,\tgk,\tgv,\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\teps=1e-06,\ttranspose=False,\tinterleave=True):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter", "modulename": "linghe.utils.scatter", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.scatter.aligned_scatter_add_kernel", "modulename": "linghe.utils.scatter", "qualname": "aligned_scatter_add_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\to_ptr,\tindices_ptr,\tweights_ptr,\tM,\tN: int,\tK: int,\tSCALE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_aligned_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_aligned_scatter_add", "kind": "function", "doc": "

    \n", "signature": "(x, outputs, indices, weights=None):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.scatter_add_kernel", "modulename": "linghe.utils.scatter", "qualname": "scatter_add_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, o_ptr, indices_ptr, M, T, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.fp32_to_bf16_kernel", "modulename": "linghe.utils.scatter", "qualname": "fp32_to_bf16_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, o_ptr, M, T, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_scatter_add", "kind": "function", "doc": "

    \n", "signature": "(x, outputs, indices):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.unpermute_with_mask_map_kernel", "modulename": "linghe.utils.scatter", "qualname": "unpermute_with_mask_map_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tgrads_ptr,\tprobs_ptr,\tmask_map_ptr,\toutput_ptr,\toutput_probs_ptr,\tnum_experts: int,\tN: int,\tPROB: int):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_unpermute_with_mask_map", "modulename": "linghe.utils.scatter", "qualname": "triton_unpermute_with_mask_map", "kind": "function", "doc": "

    \n", "signature": "(grad: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.utils.silu", "modulename": "linghe.utils.silu", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.silu.silu_and_block_quant_forward_kernel", "modulename": "linghe.utils.silu", "qualname": "silu_and_block_quant_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\tM,\tn: int,\tROUND: int,\tOUTPUT_MODE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_forward", "kind": "function", "doc": "

    \n", "signature": "(x, out=None, scale=None, round_scale=False, output_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.silu_and_block_quant_backward_kernel", "modulename": "linghe.utils.silu", "qualname": "silu_and_block_quant_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tg_ptr,\tx_ptr,\tdx_ptr,\tdx_scale_ptr,\ttranspose_dx_ptr,\ttranspose_dx_scale_ptr,\tM,\tn: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_backward", "kind": "function", "doc": "

    \n", "signature": "(g, x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.batch_weighted_silu_and_block_quant_forward_kernel", "modulename": "linghe.utils.silu", "qualname": "batch_weighted_silu_and_block_quant_forward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tx_ptr,\tweight_ptr,\tout_ptr,\tscale_ptr,\ttranspose_output_ptr,\ttranspose_scale_ptr,\tcount_ptr,\taccum_ptr,\tn: int,\tE: int,\tROUND: int,\tOUTPUT_MODE: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_forward", "kind": "function", "doc": "

    \n", "signature": "(\tx,\tweight,\tcounts,\tsplits=None,\tout=None,\tscale=None,\tround_scale=False,\toutput_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.batch_weighted_silu_and_block_quant_backward_kernel", "modulename": "linghe.utils.silu", "qualname": "batch_weighted_silu_and_block_quant_backward_kernel", "kind": "function", "doc": "

    \n", "signature": "(\tg_ptr,\tx_ptr,\tweight_ptr,\tcount_ptr,\taccum_ptr,\tdx_ptr,\tdx_scale_ptr,\ttranspose_dx_ptr,\ttranspose_dx_scale_ptr,\tdw_ptr,\tn: int,\tE: int,\tROUND: int):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_backward", "kind": "function", "doc": "

    \n", "signature": "(g, x, weight, counts, splits=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose", "modulename": "linghe.utils.transpose", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.transpose.deprecated_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "deprecated_transpose_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, t_ptr, M, N, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_depracated_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_depracated_transpose", "kind": "function", "doc": "

    \n", "signature": "(x):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, t_ptr, M, N, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_dim_0_1_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_dim_0_1_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, t_ptr, B, M, b_stride, m_stride, N: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose", "kind": "function", "doc": "

    \n", "signature": "(x, dim0=None, dim1=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.transpose_and_pad_kernel", "modulename": "linghe.utils.transpose", "qualname": "transpose_and_pad_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, t_ptr, M, N, P, H: int, W: int, EVEN: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose_and_pad", "kind": "function", "doc": "

    \n", "signature": "(x, out=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.batch_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "batch_transpose_kernel", "kind": "function", "doc": "

    \n", "signature": "(xs_ptr, xts_ptr, M, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose", "kind": "function", "doc": "

    \n", "signature": "(xs, xts=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.batch_transpose_and_pad_kernel", "modulename": "linghe.utils.transpose", "qualname": "batch_transpose_and_pad_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, t_ptr, count_ptr, accum_ptr, pad_accum_ptr, N, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose_and_pad", "kind": "function", "doc": "

    \n", "signature": "(x, count_list, x_t=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.configs", "modulename": "linghe.utils.transpose", "qualname": "configs", "kind": "variable", "doc": "

    \n", "default_value": "[<triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config 
object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>, <triton.Config object>]"}, {"fullname": "linghe.utils.transpose.opt_transpose_kernel", "modulename": "linghe.utils.transpose", "qualname": "opt_transpose_kernel", "kind": "function", "doc": "

    \n", "signature": "(x_ptr, t_ptr, M, N, D, H: int, W: int):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_opt_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_opt_transpose", "kind": "function", "doc": "

    \n", "signature": "(x):", "funcdef": "def"}]; + /** pdoc search index */const docs = [{"fullname": "linghe", "modulename": "linghe", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.facade", "modulename": "linghe.facade", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.facade.add", "modulename": "linghe.facade.add", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.add.inplace_add", "modulename": "linghe.facade.add", "qualname": "inplace_add", "kind": "function", "doc": "

in-place add y to x with mixed precision

    \n\n
    Arguments:
    \n\n
      \n
• x: tensor to be updated
    • \n
• y: tensor to add to x
    • \n
    \n\n
    Returns:
    \n\n
    \n

    updated x tensor

    \n
    \n", "signature": "(x: torch.Tensor, y: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.facade.fp32_gemm", "modulename": "linghe.facade.fp32_gemm", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.fp32_gemm.fp32_gemm", "modulename": "linghe.facade.fp32_gemm", "qualname": "fp32_gemm", "kind": "function", "doc": "

GEMM with bf16/fp16 inputs and float32 output,\ncurrently used for the MoE router GEMM.

    \n\n
    Arguments:
    \n\n
      \n
    • input: bf16/fp16 activation tensor
    • \n
    • weight: bf16/fp16 weight tensor
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output of gemm

    \n
    \n", "signature": "(input: torch.Tensor, weight: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.facade.hadamard_quant_linear", "modulename": "linghe.facade.hadamard_quant_linear", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.hadamard_quant_linear.HadamardQuantLinear", "modulename": "linghe.facade.hadamard_quant_linear", "qualname": "HadamardQuantLinear", "kind": "class", "doc": "

    a naive implementation of hadamard transformation and quantization

    \n", "bases": "torch.nn.modules.module.Module"}, {"fullname": "linghe.facade.hadamard_quant_linear.HadamardQuantLinear.__init__", "modulename": "linghe.facade.hadamard_quant_linear", "qualname": "HadamardQuantLinear.__init__", "kind": "function", "doc": "
    Arguments:
    \n\n
      \n
    • in_features: in feature number
    • \n
    • out_features: out feature number
    • \n
• bias: whether to use bias
    • \n
    • device: weight device
    • \n
    • dtype: weight dtype
    • \n
    \n", "signature": "(\tin_features: int,\tout_features: int,\tbias: bool = True,\tdevice=None,\tdtype=None)"}, {"fullname": "linghe.facade.loss", "modulename": "linghe.facade.loss", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.loss.softmax_cross_entropy", "modulename": "linghe.facade.loss", "qualname": "softmax_cross_entropy", "kind": "function", "doc": "

    softmax cross entropy

    \n\n
    Arguments:
    \n\n
      \n
    • logits: logits tensor, shape [...,dim]
    • \n
    • labels: labels tensor, shape [...]
    • \n
    • inplace: update gradient in the logits tensor if True
    • \n
    \n\n
    Returns:
    \n\n
    \n

    a tensor of per token loss

    \n
    \n", "signature": "(logits: torch.Tensor, labels: torch.Tensor, inplace: bool = False):", "funcdef": "def"}, {"fullname": "linghe.facade.norm", "modulename": "linghe.facade.norm", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.norm.rms_norm", "modulename": "linghe.facade.norm", "qualname": "rms_norm", "kind": "function", "doc": "

    rms norm of x with weight

    \n\n
    Arguments:
    \n\n
      \n
    • x: activation tensor
    • \n
    • weight: weight tensor
    • \n
    • eps: epsilon for RMS
    • \n
    \n\n
    Returns:
    \n\n
    \n

    rms output

    \n
    \n", "signature": "(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-06):", "funcdef": "def"}, {"fullname": "linghe.facade.norm.group_norm_gate", "modulename": "linghe.facade.norm", "qualname": "group_norm_gate", "kind": "function", "doc": "

    return group_rms_norm(transpose(attn_output, [0,1]), weight) * sigmoid(gate)

    \n\n
    Arguments:
    \n\n
      \n
    • attn_output: output of core attn, shape [bs, length, n_heads, head_dim]
    • \n
    • gate: gate tensor for attention output, shape [length, bs, dim]
    • \n
    • weight: weight of RMS norm, shape [dim]
    • \n
    • eps: epsilon for RMS
    • \n
    • group_size: group size of group RMS norm
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output with shape [length, bs, dim]

    \n
    \n", "signature": "(\tattn_output: torch.Tensor,\tgate: torch.Tensor,\tweight: torch.Tensor,\teps: float = 1e-06,\tgroup_size: int = 4):", "funcdef": "def"}, {"fullname": "linghe.facade.rope", "modulename": "linghe.facade.rope", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.rope.qk_norm_half_rope", "modulename": "linghe.facade.rope", "qualname": "qk_norm_half_rope", "kind": "function", "doc": "

split qkv into q/k/v, apply qk norm and half rope to q/k, and transpose q/k/v to the flash-attention layout

    \n\n
    Arguments:
    \n\n
      \n
    • qkv: QKV tensor with size of [S, B, dim], heads are interleaved
    • \n
    • q_norm_weight: rms norm weight for query
    • \n
    • k_norm_weight: rms norm weight for key
    • \n
    • freqs: Freqs tensor based on half dim.
    • \n
    • H: Number of attention heads.
    • \n
    • h: Number of key/value heads.
    • \n
    • eps: epsilon value for L2 normalization.
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • qo: shape [B, S, H, head_dim]
    • \n
    • ko: shape [B, S, h, head_dim]
    • \n
    • vo: shape [B, S, h, head_dim]
    • \n
    \n
    \n", "signature": "(\tqkv: torch.Tensor,\tq_norm_weight: torch.Tensor,\tk_norm_weight: torch.Tensor,\tfreqs: torch.Tensor,\tH: int = 32,\th: int = 4,\teps: float = 1e-06):", "funcdef": "def"}, {"fullname": "linghe.facade.smooth_quant_linear", "modulename": "linghe.facade.smooth_quant_linear", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.smooth_quant_linear.SmoothQuantLinear", "modulename": "linghe.facade.smooth_quant_linear", "qualname": "SmoothQuantLinear", "kind": "class", "doc": "

a naive implementation of a smooth-quantization linear layer

    \n", "bases": "torch.nn.modules.module.Module"}, {"fullname": "linghe.facade.smooth_quant_linear.SmoothQuantLinear.__init__", "modulename": "linghe.facade.smooth_quant_linear", "qualname": "SmoothQuantLinear.__init__", "kind": "function", "doc": "
    Arguments:
    \n\n
      \n
    • in_features: in feature number
    • \n
    • out_features: out feature number
    • \n
• bias: whether to use bias
    • \n
    • device: weight device
    • \n
    • dtype: weight dtype
    • \n
    \n", "signature": "(\tin_features: int,\tout_features: int,\tbias: bool = True,\tdevice=None,\tdtype=None)"}, {"fullname": "linghe.facade.transpose", "modulename": "linghe.facade.transpose", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.facade.transpose.transpose_dim01", "modulename": "linghe.facade.transpose", "qualname": "transpose_dim01", "kind": "function", "doc": "

transpose the first two dims of a tensor; x.ndim should not be greater than 4

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    \n\n
    Returns:
    \n\n
    \n

    a transposed tensor

    \n
    \n", "signature": "(x):", "funcdef": "def"}, {"fullname": "linghe.gemm", "modulename": "linghe.gemm", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.gemm.blockwise_fp8_gemm", "modulename": "linghe.gemm.blockwise_fp8_gemm", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.gemm.channelwise_fp8_gemm", "modulename": "linghe.gemm.channelwise_fp8_gemm", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.gemm.channelwise_fp8_gemm.triton_scaled_mm", "modulename": "linghe.gemm.channelwise_fp8_gemm", "qualname": "triton_scaled_mm", "kind": "function", "doc": "

similar to torch._scaled_mm; supports accumulating the GEMM output into c\n and low-precision output tensors

    \n\n
    Arguments:
    \n\n
      \n
    • a: left fp8 tensor
    • \n
    • b: right fp8 tensor, column-major
    • \n
    • a_scale: fp32 scale of a
    • \n
    • b_scale: fp32 scale of b
    • \n
    • out_dtype: output tensor dtype
    • \n
    • c: output tensor
    • \n
    • accum: accumulate output on c if True
    • \n
    \n\n
    Returns:
    \n\n
    \n

    c: output tensor

    \n
    \n", "signature": "(\ta: torch.Tensor,\tb: torch.Tensor,\ta_scale: torch.Tensor,\tb_scale: torch.Tensor,\tout_dtype=torch.float32,\tc=None,\taccum=True):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm", "kind": "function", "doc": "

return the fp32 GEMM result of fp16/bf16 inputs;\n it is mainly used for the MoE router GEMM\n and is NOT suitable for large GEMMs

    \n\n
    Arguments:
    \n\n
      \n
    • a: left matrix with fp16/bf16 precision
    • \n
    • b: right matrix with fp16/bf16 precision
    • \n
    \n\n
    Returns:
    \n\n
    \n

    c: output with fp32 precision

    \n
    \n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm_for_backward", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm_for_backward", "kind": "function", "doc": "

    mixed precision gemm for backward, a@b.float()

    \n\n
    Arguments:
    \n\n
      \n
    • a: input gradient, fp32
    • \n
    • b: gemm weight, bf16/fp16
    • \n
    \n\n
    Returns:
    \n\n
    \n

    c: gradient of activation

    \n
    \n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_fp32_gemm_for_update", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_fp32_gemm_for_update", "kind": "function", "doc": "

    mixed precision gemm for updating weight

    \n\n
    Arguments:
    \n\n
      \n
    • a: gradient of output, fp32
    • \n
    • b: input activation, bf16/fp16
    • \n
    \n\n
    Returns:
    \n\n
    \n

    c: gradient of weight

    \n
    \n", "signature": "(a: torch.Tensor, b: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_scaled_fp32_gemm", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_scaled_fp32_gemm", "kind": "function", "doc": "

    c = (a * scale[:, None]) @ b\nthis kernel is used to fuse RMSNorm and quantization in the MoE layer\nnative implementation:\n y = rms_norm(x),\n y_q = quantization(y),\n router_logits = y@w\nwe cannot fuse rms_norm and quantization directly\nas bf16 y is still needed for the MoE router gemm\nfused implementation:\n y_q, rms = quantization(rms_norm(x))\n router_logits = (x/rms)@y\nso a scaled fp32 gemm kernel is needed

    \n\n
    Arguments:
    \n\n
      \n
    • a: activation tensor
    • \n
    • b: weight tensor
    • \n
    • scale: scale for activation tensor, 1/rms
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output tensor

    \n
    \n", "signature": "(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.gemm.fp32_gemm.triton_scaled_fp32_gemm_for_update", "modulename": "linghe.gemm.fp32_gemm", "qualname": "triton_scaled_fp32_gemm_for_update", "kind": "function", "doc": "

    see triton_scaled_fp32_gemm

    \n\n
    Arguments:
    \n\n
      \n
    • a: y
    • \n
    • b: activation before RMS norm
    • \n
    • scale: 1/rms
    • \n
    \n\n
    Returns:
    \n\n
    \n

    dw

    \n
    \n", "signature": "(a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.quant", "modulename": "linghe.quant", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.quant.block", "modulename": "linghe.quant.block", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.block.triton_block_quant", "modulename": "linghe.quant.block", "qualname": "triton_block_quant", "kind": "function", "doc": "

    blockwise quantize x

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • block_size: block size
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • y: quantized tensor, float8_e4m3fn
    • \n
    • s: quantization scale, float32
    • \n
    \n
    \n", "signature": "(x, block_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel", "modulename": "linghe.quant.channel", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.channel.triton_row_quant", "modulename": "linghe.quant.channel", "qualname": "triton_row_quant", "kind": "function", "doc": "

    rowwise quantize x

    \n\n
    Arguments:
    \n\n
      \n
    • x: input x
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n

    x_q: quantized tensor\n x_scale: quantization scale

    \n
    \n", "signature": "(x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.triton_tokenwise_row_quant", "modulename": "linghe.quant.channel", "qualname": "triton_tokenwise_row_quant", "kind": "function", "doc": "

    rowwise quantize x with a power-of-2 dim size

    \n\n
    Arguments:
    \n\n
      \n
    • x: input x
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n

    out: quantized tensor\n scale: quantization scale

    \n
    \n", "signature": "(x, out=None, scale=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.channel.triton_transpose_row_quant", "modulename": "linghe.quant.channel", "qualname": "triton_transpose_row_quant", "kind": "function", "doc": "

    transpose x and row quantize x

    \n\n
    Arguments:
    \n\n
      \n
    • x: input x
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • x_q: quantized tensor
    • \n
    • x_scale: quantization scale
    • \n
    \n
    \n", "signature": "(x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.group", "modulename": "linghe.quant.group", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.group.triton_group_quant", "modulename": "linghe.quant.group", "qualname": "triton_group_quant", "kind": "function", "doc": "

    groupwise quantize x; groups are in rowwise format

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • group_size: group size
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • y: quantized tensor, float8_e4m3fn
    • \n
    • s: quantization scale, float32
    • \n
    \n
    \n", "signature": "(x, dtype=torch.float8_e4m3fn, group_size=128, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.quant.hadamard", "modulename": "linghe.quant.hadamard", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.quant.hadamard.triton_hadamard_quant", "modulename": "linghe.quant.hadamard", "qualname": "triton_hadamard_quant", "kind": "function", "doc": "

    apply hadamard transformation and then quantize transformed tensor

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • hm: hadamard matrix
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • x_q: rowwise quantized tensor of non-transposed x
    • \n
    • x_scale: rowwise quantization scale of non-transposed x
    • \n
    • xt_q: columnwise quantized tensor of transposed x
    • \n
    • xt_scale: columnwise quantization scale of transposed x
    • \n
    \n
    \n", "signature": "(x, hm):", "funcdef": "def"}, {"fullname": "linghe.quant.smooth", "modulename": "linghe.quant.smooth", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils", "modulename": "linghe.utils", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.utils.add", "modulename": "linghe.utils.add", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.add.triton_inplace_add", "modulename": "linghe.utils.add", "qualname": "triton_inplace_add", "kind": "function", "doc": "

    inplace add y to x

    \n\n
    Arguments:
    \n\n
      \n
    • x: Tensor
    • \n
    • y: Tensor
    • \n
    • accum: x += y if accum=True else x.copy_(y)
    • \n
    \n\n
    Returns:
    \n\n
    \n

    updated x

    \n
    \n", "signature": "(x: torch.Tensor, y: torch.Tensor, accum: bool = True):", "funcdef": "def"}, {"fullname": "linghe.utils.dot", "modulename": "linghe.utils.dot", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.dot.triton_dot", "modulename": "linghe.utils.dot", "qualname": "triton_dot", "kind": "function", "doc": "

    vector dot product, output = sum(x*y, 1),\nit is used to calculate the gradient of the router weight

    \n\n
    Arguments:
    \n\n
      \n
    • x:
    • \n
    • y:
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output of sum(x*y, 1)

    \n
    \n", "signature": "(x, y):", "funcdef": "def"}, {"fullname": "linghe.utils.gather", "modulename": "linghe.utils.gather", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map", "kind": "function", "doc": "

    make row id map, values in the tensor are the row indices

    \n\n
    Arguments:
    \n\n
      \n
    • routing_map: a tensor of 0/1 values, 1 indicates routed
    • \n
    • multiple_of: pad the tokens of each expert to a multiple of this value
    • \n
    \n\n
    Returns:
    \n\n
    \n

    row id map with shape [n_tokens, n_experts]

    \n
    \n", "signature": "(routing_map: torch.Tensor, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_make_row_id_map_and_indices", "modulename": "linghe.utils.gather", "qualname": "triton_make_row_id_map_and_indices", "kind": "function", "doc": "

    similar to triton_make_row_id_map, but outputs an indices tensor as well

    \n\n
    Arguments:
    \n\n
      \n
    • routing_map: [n_tokens, n_experts]
    • \n
    • num_out_tokens: sum(round_up_to(n_tokens, multiple_of))
    • \n
    • multiple_of: pad the tokens of each expert to a multiple of this value
    • \n
    \n\n
    Returns:
    \n\n
    \n

    row_id_map: [n_tokens, n_experts]\n row_indices: [num_out_tokens]

    \n
    \n", "signature": "(routing_map: torch.Tensor, num_out_tokens: int, multiple_of: int = 1):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_index_select", "modulename": "linghe.utils.gather", "qualname": "triton_index_select", "kind": "function", "doc": "

    index select for quantized tensor

    \n\n
    Arguments:
    \n\n
      \n
    • x: [bs, dim]
    • \n
    • indices: [K]
    • \n
    • scale: [bs]
    • \n
    \n\n
    Returns:
    \n\n
    \n

    out: output of selected x\n scale_out: scale of selected scale

    \n
    \n", "signature": "(x, indices, scale=None, out=None, scale_out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_permute_with_mask_map", "kind": "function", "doc": "

    gather quantized tensor with row id map

    \n\n
    Arguments:
    \n\n
      \n
    • inp: [num_tokens, hidden_size], rowwise quantized tensor
    • \n
    • scale: [num_tokens], quantization scale
    • \n
    • probs: router prob, used as weight
    • \n
    • row_id_map: [n_experts, num_tokens]\nindex >= 0: row index of output tensor\nindex == -1: ignore\nNote: index may not be contiguous
    • \n
    • num_out_tokens: output token count, including padding tokens
    • \n
    • contiguous: whether indices in row_id_map is contiguous,\nFalse means padded
    • \n
    • tokens_per_expert: [num_experts], token count per expert,\nnon-blocking cuda tensor
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output: permuted quantized tensor\n permuted_scale: permuted quantization scale\n permuted_probs: permuted router prob

    \n
    \n", "signature": "(\tinp: torch.Tensor,\tscale: torch.Tensor,\tprobs: torch.Tensor,\trow_id_map: torch.Tensor,\tnum_out_tokens: int,\tcontiguous: bool = True,\ttokens_per_expert: Optional[torch.Tensor] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_batch_transpose_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_batch_transpose_smooth_permute_with_indices", "kind": "function", "doc": "

    used for smooth quantization backward in megatron 0.12,\nx is gathered, requantized, padded to a multiple of 32 and transposed

    \n\n
    Arguments:
    \n\n
      \n
    • x: dy, [bs, dim], it is smooth quantized
    • \n
    • scale: [bs], quantized scale
    • \n
    • org_smooth_scale: [dim]
    • \n
    • smooth_scales: [n_experts, dim]
    • \n
    • indices: [sum(tokens_per_experts)]
    • \n
    • token_count_per_expert: [n_experts], tensor of token count per expert
    • \n
    • splits: [n_experts], list of token_count_per_expert
    • \n
    • round_scale: round quantization scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n

    x_q: [sum(roundup(tokens_per_experts)) * dim]\n x_scale: [sum(roundup(tokens_per_experts))]

    \n
    \n", "signature": "(\tx,\tscale,\torg_smooth_scale,\tsmooth_scales,\tindices,\ttoken_count_per_expert,\tsplits,\tx_q=None,\tx_scale=None,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_weighted_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_weighted_permute_with_indices", "kind": "function", "doc": "

    select, smooth, and quantize; used in megatron 0.11 all2all MoE

    \n\n
    Arguments:
    \n\n
      \n
    • grads: [bs, dim]
    • \n
    • tokens: [bs, dim]
    • \n
    • smooth_scales: [n_experts, dim]
    • \n
    • token_count_per_expert: [n_experts]
    • \n
    • indices: [n_experts*topk]
    • \n
    • reverse: whether scale is 1/scale
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n

    x_q: [bs*topk, dim]\n x_scale: [bs*topk]\n x_sum: [bs*topk]

    \n
    \n", "signature": "(\tgrads,\ttokens,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\tx_sum=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_indices", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_indices", "kind": "function", "doc": "

    select, smooth, and quantize

    \n\n
    Arguments:
    \n\n
      \n
    • grad_data: [bs, dim]
    • \n
    • grad_scale: [bs]
    • \n
    • smooth_scales: [n_experts, dim]
    • \n
    • token_count_per_expert: [n_experts]
    • \n
    • indices: [n_experts*topk]
    • \n
    • x_q: [bs*topk, dim]
    • \n
    • x_scale: [bs*topk]
    • \n
    • reverse:
    • \n
    • round_scale:
    • \n
    \n\n

    Returns:

    \n", "signature": "(\tgrad_data,\tgrad_scale,\tsmooth_scales,\ttoken_count_per_expert,\tindices,\tx_q=None,\tx_scale=None,\treverse=False,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.gather.triton_smooth_permute_with_mask_map", "modulename": "linghe.utils.gather", "qualname": "triton_smooth_permute_with_mask_map", "kind": "function", "doc": "

    gather (and optionally dequantize) and smooth quantize

    \n\n
    Arguments:
    \n\n
      \n
    • inp: [num_tokens, hidden_size], rowwise quantized tensor
    • \n
    • row_id_map: [n_experts, num_tokens], indices
    • \n
    • scale: [num_tokens, hs], rowwise_scale_inv, optional
    • \n
    • num_tokens: [n_experts]
    • \n
    • num_experts:
    • \n
    • num_out_tokens:
    • \n
    • hidden_size:
    • \n
    • smooth_scales: [n_experts, hidden_size]
    • \n
    • reverse:
    • \n
    • round_scale:
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • output: output tensor
    • \n
    • permuted_scale: permuted scale if scale is not None
    • \n
    \n
    \n", "signature": "(\tinp: torch.Tensor,\trow_id_map: torch.Tensor,\tscale: torch.Tensor,\tnum_tokens: int,\tnum_experts: int,\tnum_out_tokens: int,\thidden_size: int,\tsmooth_scales: torch.Tensor,\treverse=True,\tround_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.loss", "modulename": "linghe.utils.loss", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_forward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_forward", "kind": "function", "doc": "

    compute token-wise softmax cross entropy loss

    \n\n
    Arguments:
    \n\n
      \n
    • logits: logits tensor
    • \n
    • labels: labels tensor
    • \n
    \n\n
    Returns:
    \n\n
    \n

    loss of each token

    \n
    \n", "signature": "(logits, labels):", "funcdef": "def"}, {"fullname": "linghe.utils.loss.triton_softmax_cross_entropy_backward", "modulename": "linghe.utils.loss", "qualname": "triton_softmax_cross_entropy_backward", "kind": "function", "doc": "

    backward of softmax cross entropy loss

    \n\n
    Arguments:
    \n\n
      \n
    • logits: logit tensor, [bs, dim]
    • \n
    • labels: label tensor, [bs]
    • \n
    • sum_exp: [bs]
    • \n
    • max_logit: [bs]
    • \n
    • input_grad: gradient, [bs, dim]
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output_grad: [bs, dim]

    \n
    \n", "signature": "(logits, labels, sum_exp, max_logit, input_grad, output_grad=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm", "modulename": "linghe.utils.norm", "kind": "module", "doc": "

    \n"}, {"fullname": "linghe.utils.norm.triton_rms_norm_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_forward", "kind": "function", "doc": "

    rms norm

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • weight: weight of rms norm
    • \n
    • eps: epsilon of rms norm
    • \n
    \n\n
    Returns:
    \n\n
    \n

    out: output tensor

    \n
    \n", "signature": "(x, weight, eps=1e-06, out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_rms_norm_and_block_quant_forward", "modulename": "linghe.utils.norm", "qualname": "triton_rms_norm_and_block_quant_forward", "kind": "function", "doc": "

    Fused RMSNorm forward and block quantization.

    \n\n
    Arguments:
    \n\n
      \n
    • x: Input tensor, shape [M, N]
    • \n
    • weight: RMSNorm weight, shape [N]
    • \n
    • eps: epsilon value for L2 normalization.
    • \n
    • out: output of quantization data
    • \n
    • scale: output of quantization scale.
    • \n
    • rms: output of rms
    • \n
    • round_scale: Set whether to force power of 2 scales.
    • \n
    • output_mode: one of {0, 1, 2}.\n0: only output non-transpose tensor\n1: only output transposed tensor\n2: return both
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • out: quantization data.
    • \n
    • scale: quantization scale.
    • \n
    • rms: Reciprocal of the root mean square of the\n input calculated over the last dimension.
    • \n
    • transpose_output: quantization data of transposed gradient.
    • \n
    • transpose_scale: quantization scale of transposed gradient.
    • \n
    \n
    \n", "signature": "(\tx: torch.Tensor,\tweight: torch.Tensor,\teps: float = 1e-06,\tout: Optional[torch.Tensor] = None,\tscale: Optional[torch.Tensor] = None,\trms: Optional[torch.Tensor] = None,\tround_scale: bool = False,\toutput_mode: int = 2):", "funcdef": "def"}, {"fullname": "linghe.utils.norm.triton_group_norm_gate_forward", "modulename": "linghe.utils.norm", "qualname": "triton_group_norm_gate_forward", "kind": "function", "doc": "

    norm and gate in linear attention

    \n\n
    Arguments:
    \n\n
      \n
    • x: output of attn, [bs, length, n_heads, head_dim]
    • \n
    • gate: gate tensor, [length, bs, dim]
    • \n
    • weight: rms norm weight, [dim]
    • \n
    • eps: epsilon of rms norm
    • \n
    • group_size: group size of group rms norm
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output tensor

    \n
    \n", "signature": "(x: torch.Tensor, gate, weight, eps=1e-06, group_size=4):", "funcdef": "def"}, {"fullname": "linghe.utils.rearange", "modulename": "linghe.utils.rearange", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.rearange.triton_split_and_cat", "modulename": "linghe.utils.rearange", "qualname": "triton_split_and_cat", "kind": "function", "doc": "

    split x into multiple tensors and concatenate with indices,\nit is used for permutation in MoE

    \n\n
    Arguments:
    \n\n
      \n
    • x: [bs, dim]
    • \n
    • counts: [n_split]
    • \n
    • indices: [n_split]
    • \n
    • scales: [bs]
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • y: output tensor
    • \n
    • output_scales: output scales if scales is not None
    • \n
    \n
    \n", "signature": "(x, counts, indices, scales=None):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce", "modulename": "linghe.utils.reduce", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.reduce.triton_abs_max", "modulename": "linghe.utils.reduce", "qualname": "triton_abs_max", "kind": "function", "doc": "

    columnwise abs max of x, it is used in smooth quantization

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor, may be quantized tensor
    • \n
    • scale: quantization scale if x is quantized
    • \n
    • smooth_scale: optional smooth scale
    • \n
    • min_value: lower bound of the output, i.e., output = max(max(abs(x), axis), min_value)
    • \n
    • axis: reduce axis
    • \n
    \n\n
    Returns:
    \n\n
    \n

    max tensor

    \n
    \n", "signature": "(x, scale=None, smooth_scale=None, min_value=1e-30, axis=0):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_count_zero", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_count_zero", "kind": "function", "doc": "

    count zeros in a tensor list; it is used to monitor zeros in gradient tensors

    \n\n
    Arguments:
    \n\n
      \n
    • xs: input tensors
    • \n
    \n\n
    Returns:
    \n\n
    \n

    a single-value int64 tensor

    \n
    \n", "signature": "(xs):", "funcdef": "def"}, {"fullname": "linghe.utils.reduce.triton_batch_sum_with_ord", "modulename": "linghe.utils.reduce", "qualname": "triton_batch_sum_with_ord", "kind": "function", "doc": "

    return sum(abs(x)**ord).

    \n\n
    Arguments:
    \n\n
      \n
    • xs: Tensor lists.
    • \n
    • ord: the order of the norm, i.e., the exponent applied to abs(x).
    • \n
    \n\n
    Returns:
    \n\n
    \n

    a single-value fp32 tensor

    \n
    \n", "signature": "(xs, ord=2):", "funcdef": "def"}, {"fullname": "linghe.utils.rope", "modulename": "linghe.utils.rope", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.rope.triton_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_half_rope_forward", "kind": "function", "doc": "

    apply norm to qk, then apply half rope to qk

    \n\n
    Arguments:
    \n\n
      \n
    • q: query tensor, [len, bs, q_head, head_dim]
    • \n
    • k: key tensor, [len, bs, kv_head, head_dim]
    • \n
    • freqs: rope freqs
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • qo: query output
    • \n
    • ko: key output
    • \n
    \n
    \n", "signature": "(q, k, freqs):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_forward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_forward", "kind": "function", "doc": "

    split qkv to q/k/v, apply qk norm and half rope to q/k,\n transpose q/k/v to flash-attention layout

    \n\n
    Arguments:
    \n\n
      \n
    • qkv: QKV tensor with size of [S, B, dim], heads are interleaved
    • \n
    • q_norm_weight: rms norm weight for query
    • \n
    • k_norm_weight: rms norm weight for key
    • \n
    • freqs: Freqs tensor based on half dim.
    • \n
    • H: Number of attention heads.
    • \n
    • h: Number of key/value heads.
    • \n
    • eps: epsilon value for L2 normalization.
    • \n
    • interleave: whether head of qkv is interleaved, i.e., [qqkvqqkv]
    • \n
    • transpose: whether qkv is transposed, i.e., [S, B, dim],\nonly the transposed format is currently supported
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • qo: shape [B, S, H, head_dim]
    • \n
    • ko: shape [B, S, h, head_dim]
    • \n
    • vo: shape [B, S, h, head_dim]
    • \n
    \n
    \n", "signature": "(\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\tH=32,\th=4,\teps=1e-06,\tinterleave=True,\ttranspose=False):", "funcdef": "def"}, {"fullname": "linghe.utils.rope.triton_qk_norm_and_half_rope_backward", "modulename": "linghe.utils.rope", "qualname": "triton_qk_norm_and_half_rope_backward", "kind": "function", "doc": "

    backward kernel of triton_qk_norm_and_half_rope_forward

    \n\n
    Arguments:
    \n\n
      \n
    • gq: gradient of qo, [len, bs, q_head, head_dim]
    • \n
    • gk: gradient of ko, [len, bs, q_head, head_dim]
    • \n
    • gv: gradient of vo, [len, bs, q_head, head_dim]
    • \n
    • qkv: input qkv
    • \n
    • q_norm_weight:
    • \n
    • k_norm_weight:
    • \n
    • freqs:
    • \n
    • eps:
    • \n
    • transpose:
    • \n
    • interleave:
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • dqkv: gradient of qkv
    • \n
    • dqw: gradient of q_norm_weight
    • \n
    • dkw: gradient of k_norm_weight
    • \n
    \n
    \n", "signature": "(\tgq,\tgk,\tgv,\tqkv,\tq_norm_weight,\tk_norm_weight,\tfreqs,\teps=1e-06,\ttranspose=False,\tinterleave=True):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter", "modulename": "linghe.utils.scatter", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.scatter.triton_aligned_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_aligned_scatter_add", "kind": "function", "doc": "

    scatter_add for megatron 0.11

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • outputs: output tensor
    • \n
    • indices: gather indices
    • \n
    • weights: rowwise weight, it is router prob in MoE router
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output tensor

    \n
    \n", "signature": "(\tx: torch.Tensor,\toutputs: torch.Tensor,\tindices: torch.Tensor,\tweights: Optional[torch.Tensor] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_scatter_add", "modulename": "linghe.utils.scatter", "qualname": "triton_scatter_add", "kind": "function", "doc": "

    naive version of scatter add, very slow

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • outputs: output tensor
    • \n
    • indices: indices
    • \n
    \n\n
    Returns:
    \n\n
    \n

    output tensor

    \n
    \n", "signature": "(x, outputs, indices):", "funcdef": "def"}, {"fullname": "linghe.utils.scatter.triton_unpermute_with_mask_map", "modulename": "linghe.utils.scatter", "qualname": "triton_unpermute_with_mask_map", "kind": "function", "doc": "

    scatter add with row id map

    \n\n
    Arguments:
    \n\n
      \n
    • grad: gradient tensor, [num_out_tokens, hidden_size]
    • \n
    • row_id_map: row id map, [n_experts, num_tokens]
    • \n
    • probs: [num_out_tokens]
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • output: [num_tokens, hidden_size]
    • \n
    • restore_probs: [num_tokens, num_experts]
    • \n
    \n
    \n", "signature": "(grad: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor):", "funcdef": "def"}, {"fullname": "linghe.utils.silu", "modulename": "linghe.utils.silu", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.silu.triton_weighted_silu_forward", "modulename": "linghe.utils.silu", "qualname": "triton_weighted_silu_forward", "kind": "function", "doc": "

    compute silu(x)*weight, used in bf16/fp16 training with MoE

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • weight: tokenwise weight
    • \n
    \n\n
    Returns:
    \n\n
    \n

    out: output tensor

    \n
    \n", "signature": "(x, weight=None, out=None):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_weighted_silu_backward", "modulename": "linghe.utils.silu", "qualname": "triton_weighted_silu_backward", "kind": "function", "doc": "

    backward of triton_weighted_silu_forward

    \n\n
    Arguments:
    \n\n
      \n
    • g: gradient tensor
    • \n
    • x: input tensor
    • \n
    • weight: weight tensor
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • dx: gradient of x
    • \n
    • dw: gradient of weight
    • \n
    \n
    \n", "signature": "(\tg: torch.Tensor,\tx: torch.Tensor,\tweight: Optional[torch.Tensor] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_forward", "kind": "function", "doc": "

    fused silu and blockwise quantization, used in shared expert

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    • output_mode: one of {0, 1, 2}\n0: only output non-transposed quantized tensor\n1: only output transposed quantized tensor\n2: output both
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • out: quantized tensor
    • \n
    • scale: quantization scale
    • \n
    • transpose_output: quantized tensor of transposed output
    • \n
    • transpose_scale: quantization scale of transposed output
    • \n
    \n
    \n", "signature": "(x, out=None, scale=None, round_scale=False, output_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_silu_and_block_quant_backward", "kind": "function", "doc": "

    backward of triton_silu_and_block_quant_forward

    \n\n
    Arguments:
    \n\n
      \n
    • g: gradient
    • \n
    • x: input tensor
    • \n
    • round_scale: whether round to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • dx: quantized non-transposed gradient
    • \n
    • dx_scale: scales of quantization non-transposed gradient
    • \n
    • transpose_dx: quantized transposed gradient
    • \n
    • transpose_dx_scale: scales of quantization transposed gradient
    • \n
    \n
    \n", "signature": "(g, x, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_forward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_forward", "kind": "function", "doc": "

    silu and blockwise quantize activation in routed experts

    \n\n
    Arguments:
    \n\n
      \n
    • x: activation tensor in routed experts
    • \n
    • weight: router prob tensor
    • \n
    • counts: cuda tensor of token count per expert
    • \n
    • splits: python int list of token count per expert
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    • output_mode: one of {0, 1, 2}\n0: only output non-transposed quantized tensor\n1: only output transposed quantized tensor\n2: output both
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • out: quantized tensor
    • \n
    • scale: quantization scale
    • \n
    • transpose_output: quantized tensor of transposed output
    • \n
    • transpose_scale: quantization scale of transposed output
    • \n
    \n
    \n", "signature": "(\tx,\tweight,\tcounts,\tsplits=None,\tout=None,\tscale=None,\tround_scale=False,\toutput_mode=2):", "funcdef": "def"}, {"fullname": "linghe.utils.silu.triton_batch_weighted_silu_and_block_quant_backward", "modulename": "linghe.utils.silu", "qualname": "triton_batch_weighted_silu_and_block_quant_backward", "kind": "function", "doc": "

    backward of triton_batch_weighted_silu_and_block_quant_forward

    \n\n
    Arguments:
    \n\n
      \n
    • g: gradient
    • \n
    • x: input tensor
    • \n
    • weight: router prob tensor
    • \n
    • counts: cuda tensor of token count per expert
    • \n
    • splits: python int list of token count per expert
    • \n
    • round_scale: whether round scale to power of 2
    • \n
    \n\n
    Returns:
    \n\n
    \n
      \n
    • dx: quantized non-transposed gradient
    • \n
    • dx_scale: scales of quantization non-transposed gradient
    • \n
    • dw: gradient of weight
    • \n
    • transpose_dx: quantized transposed gradient
    • \n
    • transpose_dx_scale: scales of quantization transposed gradient
    • \n
    \n
    \n", "signature": "(g, x, weight, counts, splits=None, round_scale=False):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose", "modulename": "linghe.utils.transpose", "kind": "module", "doc": "

    Copyright (c) Ant Financial Service Group and its affiliates.

    \n"}, {"fullname": "linghe.utils.transpose.triton_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose", "kind": "function", "doc": "

    transpose x with dim0 and dim1

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • dim0: dim 0
    • \n
    • dim1: dim 1
    • \n
    \n\n
    Returns:
    \n\n
    \n

    transposed tensor

    \n
    \n", "signature": "(\tx: torch.Tensor,\tdim0: Optional[int] = None,\tdim1: Optional[int] = None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_transpose_and_pad", "kind": "function", "doc": "

    transpose x and pad the column size to a multiple of 32,\nit is used for calculating the gradient of weight with torch._scaled_mm

    \n\n
    Arguments:
    \n\n
      \n
    • x: input tensor
    • \n
    • out:
    • \n
    • pad: whether need padding
    • \n
    \n\n
    Returns:
    \n\n
    \n

    out: output tensor

    \n
    \n", "signature": "(x, out=None, pad=True):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose", "kind": "function", "doc": "

    batch transpose x

    \n\n
    Arguments:
    \n\n
      \n
    • xs: input tensor list, [M, N]*expert
    • \n
    \n\n
    Returns:
    \n\n
    \n

    xts: output tensor list, [N,M]*expert

    \n
    \n", "signature": "(xs, xts=None):", "funcdef": "def"}, {"fullname": "linghe.utils.transpose.triton_batch_transpose_and_pad", "modulename": "linghe.utils.transpose", "qualname": "triton_batch_transpose_and_pad", "kind": "function", "doc": "

    transpose and pad each tensor stored in x

    \n\n
    Arguments:
    \n\n
      \n
    • x: [sum(bs), N]
    • \n
    • count_list: a python list of token count
    • \n
    • pad: whether to pad to a multiple of 32,\npadding values should be filled with 0 if padded
    • \n
    \n\n
    Returns:
    \n\n
    \n

    x_t: output tensor

    \n
    \n", "signature": "(x, count_list, x_t=None, pad=True):", "funcdef": "def"}]; // mirrored in build-search-index.js (part 1) // Also split on html tags. this is a cheap heuristic, but good enough. diff --git a/linghe/facade/add.py b/linghe/facade/add.py index c40aca4..cac4306 100644 --- a/linghe/facade/add.py +++ b/linghe/facade/add.py @@ -9,9 +9,7 @@ class InplaceAddFunction(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, x: torch.Tensor, y: torch.Tensor): return triton_inplace_add(x, y) @@ -28,6 +26,6 @@ def inplace_add(x: torch.Tensor, y: torch.Tensor): x: to be updated y: add to x Returns: - return updated x tensor + updated x tensor """ return InplaceAddFunction.apply(x, y) \ No newline at end of file diff --git a/linghe/facade/fp32_gemm.py b/linghe/facade/fp32_gemm.py index 8ab9ae9..2f6ff34 100644 --- a/linghe/facade/fp32_gemm.py +++ b/linghe/facade/fp32_gemm.py @@ -11,9 +11,7 @@ class Fp32GEMM(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, input: torch.Tensor, weight: torch.Tensor): shape = input.shape diff --git a/linghe/facade/hadamard_quant_linear.py b/linghe/facade/hadamard_quant_linear.py index 586f104..5b3dd45 100644 --- a/linghe/facade/hadamard_quant_linear.py +++ b/linghe/facade/hadamard_quant_linear.py @@ -34,8 +34,8 @@ def forward( output = torch._scaled_mm(x_q, w_q.t(), - scale_a=x_scale, - scale_b=w_scale, + scale_a=x_scale.view(-1,1), + scale_b=w_scale.view(1,-1), out_dtype=ctx.out_dtype, use_fast_accum=True ) @@ -61,7 +61,6 @@ def backward( output_grad: torch.Tensor, ): xt_q, xt_scale, wt_q, wt_scale, hadamard_matrix = ctx.saved_tensors - results = [None, None, None, None] output_grad = output_grad.view(-1, output_grad.shape[-1]) @@ -69,32 +68,33 @@ def backward( dx = torch._scaled_mm(y_q, wt_q.t(), - scale_a=y_scale, - scale_b=wt_scale, + scale_a=y_scale.view(-1,1), + scale_b=wt_scale.view(1,-1), out_dtype=ctx.out_dtype, use_fast_accum=True ) - # calculate input grad and assign to results[0] - results[0] = dx.view(ctx.input_shape) + dx = dx.view(ctx.input_shape) - # calculate weight grad and assign to results[1] dw = torch._scaled_mm(yt_q, xt_q.t(), - scale_a=yt_scale, - scale_b=xt_scale, + scale_a=yt_scale.view(-1,1), + scale_b=xt_scale.view(1,-1), out_dtype=ctx.out_dtype, use_fast_accum=True ) - results[1] = dw + db = None if ctx.bias_requires_grad: - # calculate bias grad and assign to results[2] - results[2] = torch.sum(output_grad, dim=0) + db = torch.sum(output_grad, dim=0) + + return dx, dw, db, None - return tuple(results) class HadamardQuantLinear(torch.nn.Module): + """ + a naive implementation of hadamard transformation and quantization + """ def __init__( self, in_features: int, @@ -104,14 +104,12 @@ def __init__( dtype=None ): """ - a naive implementation of hadamard transformation and quantization Args: in_features: in feature number out_features: out feature number bias: whether use bias device: weight device dtype: weight dtype - impl: implementation of hadamard quantization """ super().__init__() self.in_features = in_features @@ -145,6 +143,7 @@ def _hadamard_matrix(self, size, device=None, dtype=None, norm=False): return m def forward(self, input: torch.Tensor) -> torch.Tensor: + """""" if self.training: return _HadamardQuantLinear.apply(input, self.weight, self.bias, self.hadamard_matrix) @@ -155,9 +154,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return output def extra_repr(self) -> str: + """""" return f"in_features={self.in_features}, 
out_features={self.out_features}, bias={self.bias is not None}" def reset_parameters(self): + """""" self.weight.data.normal_(mean=0.0, std=0.02) if self.bias is not None: self.bias.data.zero_() diff --git a/linghe/facade/loss.py b/linghe/facade/loss.py index 1fa7294..0feac8a 100644 --- a/linghe/facade/loss.py +++ b/linghe/facade/loss.py @@ -10,9 +10,7 @@ class SoftmaxCrossEntropyFunction(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, logits, labels, inplace=False): shape = logits.shape @@ -48,9 +46,8 @@ def softmax_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, inplace: b logits: logits tensor, shape [...,dim] labels: labels tensor, shape [...] inplace: update gradient in the `logits` tensor if True - Returns: - per token loss + a tensor of per token loss """ assert logits.is_contiguous() assert labels.is_contiguous() @@ -58,9 +55,7 @@ def softmax_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, inplace: b class GradScalingFunction(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, x, coef=0.2): ctx.coef = coef diff --git a/linghe/facade/norm.py b/linghe/facade/norm.py index 435942f..9dc90e5 100644 --- a/linghe/facade/norm.py +++ b/linghe/facade/norm.py @@ -10,9 +10,7 @@ class RMSNormFunction(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, x, weight, eps=1e-6): output = triton_rms_norm_forward( @@ -56,9 +54,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6): return RMSNormFunction.apply(x, weight, eps) class GroupNormGateFunction(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, attn_output, gate, weight, eps=1e-6, group_size=4): output = triton_group_norm_gate_forward( diff --git a/linghe/facade/rope.py b/linghe/facade/rope.py index 3719095..f845183 100644 --- a/linghe/facade/rope.py +++ b/linghe/facade/rope.py @@ -10,9 +10,7 @@ class QkNormHalfRopeFunction(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, qkv, q_norm_weight, k_norm_weight, freqs, H=32, h=4, eps=1e-6): @@ -71,9 +69,9 @@ def qk_norm_half_rope(qkv: torch.Tensor, eps: epsilon value for L2 normalization. 
Returns: - qo: shape [B, S, H, head_dim] - ko: shape [B, S, h, head_dim] - vo: shape [B, S, h, head_dim] + - qo: shape [B, S, H, head_dim] + - ko: shape [B, S, h, head_dim] + - vo: shape [B, S, h, head_dim] """ return QkNormHalfRopeFunction.apply(qkv, q_norm_weight, diff --git a/linghe/facade/smooth_quant_linear.py b/linghe/facade/smooth_quant_linear.py index fccbabb..dd284e2 100644 --- a/linghe/facade/smooth_quant_linear.py +++ b/linghe/facade/smooth_quant_linear.py @@ -19,19 +19,22 @@ def forward( ctx, input: torch.Tensor, weight: torch.Tensor, + bias: Optional[torch.Tensor], smooth_scale: torch.Tensor, - bias: Optional[torch.Tensor] ): ctx.input_requires_grad = input.requires_grad ctx.weight_requires_grad = weight.requires_grad ctx.bias_requires_grad = bias is not None and bias.requires_grad - ctx.out_dtype = input.dtype ctx.input_shape = input.shape + + round_scale = True + ctx.round_scale = round_scale + input = input.view(-1, input.shape[-1]) - x_q, x_scale, x_maxs = triton_smooth_quant(input, 1 / smooth_scale) - w_q, w_scale, w_maxs = triton_smooth_quant(weight, smooth_scale) + x_q, x_scale, x_maxs = triton_smooth_quant(input, 1 / smooth_scale, round_scale=round_scale) + w_q, w_scale, w_maxs = triton_smooth_quant(weight, smooth_scale, round_scale=round_scale) output = torch._scaled_mm(x_q, w_q.t(), @@ -60,26 +63,28 @@ def backward( ctx, output_grad: torch.Tensor ): + x_q, x_s, w_q, w_s, smooth_scale = ctx.saved_tensors - results = [None, None, None, None] output_grad = output_grad.view(-1, output_grad.shape[-1]) - - y_q, y_scale, y_maxs = triton_smooth_quant(output_grad, w_s) + round_scale = ctx.round_scale + y_q, y_scale, y_maxs = triton_smooth_quant(output_grad, + w_s, + reverse=True, + round_scale=round_scale) wt_q = triton_transpose_and_pad(w_q, pad=True) dx = torch._scaled_mm(y_q, - wt_q.t(), - scale_a=y_scale.view(-1, 1), - scale_b=smooth_scale.view(1, -1), - out_dtype=ctx.out_dtype, - use_fast_accum=True) + wt_q.t(), + scale_a=y_scale.view(-1, 1), + scale_b=smooth_scale.view(1, -1), + out_dtype=ctx.out_dtype, + use_fast_accum=True) - # calculate input grad and assign to results[0] - results[0] = dx.view(ctx.input_shape) - - # calculate weight grad and assign to results[1] - yt_q, yt_scale, yt_maxs = triton_transpose_smooth_quant(output_grad, x_s) + yt_q, yt_scale = triton_transpose_smooth_quant(output_grad, + x_s, + reverse=True , + round_scale=round_scale) xt_q = triton_transpose_and_pad(x_q, pad=True) dw = torch._scaled_mm(yt_q, @@ -89,16 +94,17 @@ def backward( out_dtype=ctx.out_dtype, use_fast_accum=True) - results[1] = dw - + db = None if ctx.bias_requires_grad: - # calculate bias grad and assign to results[2] - results[2] = torch.sum(output_grad, dim=0) + db = torch.sum(output_grad, dim=0) - return tuple(results) + return dx, dw, db, None -class QuantLinear(torch.nn.Module): +class SmoothQuantLinear(torch.nn.Module): + """ + a naive implementation of smooth quantization linear + """ def __init__( self, in_features: int, @@ -107,6 +113,14 @@ def __init__( device=None, dtype=None ): + """ + Args: + in_features: in feature number + out_features: out feature number + bias: whether use bias + device: weight device + dtype: weight dtype + """ super().__init__() self.in_features = in_features self.out_features = out_features @@ -120,13 +134,13 @@ def __init__( self.bias = None self.gap_step = 16 - self.decay_coef = 0.9 self.smooth_scale = None self.smooth_update_step = 0 self.reset_parameters() def forward(self, input: torch.Tensor) -> torch.Tensor: + """""" if self.training: 
if self.smooth_update_step % self.gap_step == 0: @@ -134,10 +148,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: weight_maxs = triton_abs_max(self.weight.data) self.smooth_scale = torch.sqrt(input_maxs * weight_maxs) - output, smooth_scale = _SmoothQuantLinear.apply(input, - self.weight, - self.bias, - self.smooth_scale) + output = _SmoothQuantLinear.apply(input, + self.weight, + self.bias, + self.smooth_scale) self.smooth_update_step += 1 else: output = input @ self.weight.t() @@ -146,9 +160,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return output def extra_repr(self) -> str: + """""" return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" def reset_parameters(self): + """""" self.weight.data.normal_(mean=0.0, std=0.02) if self.bias is not None: self.bias.data.zero_() diff --git a/linghe/facade/transpose.py b/linghe/facade/transpose.py index 9d8de83..9332e18 100644 --- a/linghe/facade/transpose.py +++ b/linghe/facade/transpose.py @@ -9,9 +9,7 @@ class TransposeDim01Function(torch.autograd.Function): - """ - - """ + """""" @staticmethod def forward(ctx, x): return triton_transpose(x, dim0=0, dim1=1) diff --git a/linghe/gemm/blockwise_fp8_gemm.py b/linghe/gemm/blockwise_fp8_gemm.py index da9416a..f97a1d0 100644 --- a/linghe/gemm/blockwise_fp8_gemm.py +++ b/linghe/gemm/blockwise_fp8_gemm.py @@ -72,6 +72,7 @@ def triton_bb_fp8_gemm(a: torch.Tensor, b_s: torch.Tensor, out_dtype=torch.bfloat16, block_size=128): + """""" assert a.is_contiguous() and b.is_contiguous() assert a_s.is_contiguous() and b_s.is_contiguous() K = a.size(-1) @@ -155,6 +156,7 @@ def triton_tb_fp8_gemm(a: torch.Tensor, b_s: torch.Tensor, out_dtype=torch.bfloat16, block_size=128): + """""" assert a.is_contiguous() and b.is_contiguous() assert a_s.is_contiguous() and b_s.is_contiguous() K = a.size(-1) @@ -227,6 +229,7 @@ def triton_tt_fp8_gemm(a: torch.Tensor, b_s: torch.Tensor, out_dtype=torch.bfloat16, block_size=128): + """""" assert a.is_contiguous() and b.is_contiguous() assert a_s.is_contiguous() and b_s.is_contiguous() K = a.size(-1) diff --git a/linghe/gemm/fp32_gemm.py b/linghe/gemm/fp32_gemm.py index 8f44067..5ad778c 100644 --- a/linghe/gemm/fp32_gemm.py +++ b/linghe/gemm/fp32_gemm.py @@ -297,7 +297,7 @@ def triton_scaled_fp32_gemm(a: torch.Tensor, scale: scale for activation tensor, 1/rms Returns: - + output tensor """ assert a.is_contiguous() and b.is_contiguous() M, K = a.size() diff --git a/linghe/quant/block.py b/linghe/quant/block.py index cfb37fa..e7d093f 100644 --- a/linghe/quant/block.py +++ b/linghe/quant/block.py @@ -39,8 +39,8 @@ def triton_block_quant(x, round_scale: whether round scale to power of 2 Returns: - y: quantized tensor, float8_e4m3fn - s: quantization scale, float32 + - y: quantized tensor, float8_e4m3fn + - s: quantization scale, float32 """ M, N = x.size() y = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device) diff --git a/linghe/quant/channel.py b/linghe/quant/channel.py index 01a9571..462b6cb 100644 --- a/linghe/quant/channel.py +++ b/linghe/quant/channel.py @@ -199,8 +199,8 @@ def triton_transpose_row_quant(x, round_scale=False): round_scale: whether round scale to power of 2 Returns: - x_q: quantized tensor - x_scale: quantization scale + - x_q: quantized tensor + - x_scale: quantization scale """ M, N = x.shape @@ -270,8 +270,3 @@ def channel_quant_update(y, x): use_fast_accum=True) return output, y_q, x_q, y_scale, x_scale - -def fp8_channel_f_and_b(x, w, y): - channel_quant_forward(x, w) - 
channel_quant_backward(y, w) - channel_quant_update(y, x) diff --git a/linghe/quant/group.py b/linghe/quant/group.py index 9dec9b8..b46ce47 100644 --- a/linghe/quant/group.py +++ b/linghe/quant/group.py @@ -42,8 +42,8 @@ def triton_group_quant(x, round_scale: whether round scale to power of 2 Returns: - y: quantized tensor, float8_e4m3fn - s: quantization scale, float32 + - y: quantized tensor, float8_e4m3fn + - s: quantization scale, float32 """ M, N = x.shape K = 16 diff --git a/linghe/quant/hadamard.py b/linghe/quant/hadamard.py index 1bd77bc..0c59f96 100644 --- a/linghe/quant/hadamard.py +++ b/linghe/quant/hadamard.py @@ -125,10 +125,10 @@ def triton_hadamard_quant(x, hm): x: input tensor hm: hamadard matrix Returns: - x_q: rowwise quantized tensor of non-transposed x - x_scale: rowwise quantization scale of non-transposed x - xt_q: columnwise quantized tensor of transposed x - xt_scale: columnwise quantization scale of transposed x + - x_q: rowwise quantized tensor of non-transposed x + - x_scale: rowwise quantization scale of non-transposed x + - xt_q: columnwise quantized tensor of transposed x + - xt_scale: columnwise quantization scale of transposed x """ M, N = x.shape device = x.device diff --git a/linghe/quant/smooth.py b/linghe/quant/smooth.py index 4844511..93ac054 100644 --- a/linghe/quant/smooth.py +++ b/linghe/quant/smooth.py @@ -150,9 +150,7 @@ def blockwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, max_ptr, def triton_smooth_quant(x, smooth_scale, x_q=None, x_scale=None, reverse=False, round_scale=False, calibrate=False): - """ - - """ + """""" M, N = x.shape device = x.device if x_q is None: @@ -291,9 +289,7 @@ def subrow_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, qs_ptr, def triton_subrow_smooth_quant(x, smooth_scale, x_q, x_scale, subrow_scales, offset, size, reverse=False, round_scale=False): - """ - - """ + """""" M, N = x_q.shape W = 128 if offset % N == 0: @@ -369,9 +365,7 @@ def depracated_tokenwise_smooth_quant_kernel(x_ptr, q_ptr, ss_ptr, def triton_depracated_tokenwise_smooth_quant(x, smooth_scale, x_q=None, x_scale=None, reverse=False, round_scale=False): - """ - - """ + """""" # row-wise read, row-wise write M, N = x.shape device = x.device @@ -455,9 +449,7 @@ def triton_batch_smooth_quant(x, smooth_scales, token_count_per_expert, x_q=None, x_scale=None, x_maxs=None, reverse=False, round_scale=False, calibrate=False): - """ - - """ + """""" M, N = x.shape device = x.device n_expert = token_count_per_expert.shape[0] @@ -578,9 +570,7 @@ def triton_batch_pad_transpose_smooth_quant(x, splits, x_q=None, x_scale=None, x_maxs=None, reverse=False, round_scale=False): - """ - - """ + """""" M, N = x.shape device = x.device n_expert = token_count_per_expert.shape[0] @@ -694,9 +684,7 @@ def triton_transpose_smooth_quant(x, round_scale=False): # col-wise read, row-wise write # M should be padded if M % 32 != 0 - """ - - """ + """""" M, N = x.shape device = x.device P = (M + 31) // 32 * 32 if pad else M @@ -823,9 +811,7 @@ def triton_transpose_rescale_smooth_quant(x_q, org_smooth_scale, reverse=True, pad=False, round_scale=False): - """ - - """ + """""" assert reverse M, N = x_q.shape device = x_q.device @@ -886,9 +872,7 @@ def triton_transpose_rescale_smooth_quant(x_q, org_smooth_scale, # dwT = yT @ x def triton_smooth_quant_input(x, smooth_scale, x_q=None, x_scale=None, xt_q=None, transpose=True, pad=True, round_scale=False): - """ - - """ + """""" x_q, x_scale, x_maxs = triton_smooth_quant(x, smooth_scale, x_q=x_q, x_scale=x_scale, reverse=False, 
round_scale=round_scale) @@ -912,9 +896,7 @@ def triton_smooth_quant_gradient(y, transpose=True, pad=True, round_scale=False): - """ - - """ + """""" assert reverse, ("args `smooth_scale` and/or `transpose_smooth_scale` " "must be in reciprocal format in triton_smooth_quant_grad") y_q, y_scale, _ = triton_smooth_quant(y, smooth_scale, reverse=True, @@ -937,9 +919,7 @@ def triton_smooth_quant_weight(w, quant_scale, subrow_scales, offset=0, round_scale=False): - """ - - """ + """""" assert w.ndim == 1 assert w_q.size(1) == smooth_scale.size(0) diff --git a/linghe/utils/gather.py b/linghe/utils/gather.py index abba7e4..642db26 100644 --- a/linghe/utils/gather.py +++ b/linghe/utils/gather.py @@ -843,8 +843,7 @@ def triton_smooth_permute_with_mask_map( round_scale=False ): """ - gather and optional dequant and smooth quant - + gather ( and optional dequant) and smooth quant Args: inp: [num_tokens, hidden_size], rowwise quantized tensor row_id_map: [n_experts, num_tokens], indices @@ -858,7 +857,8 @@ def triton_smooth_permute_with_mask_map( round_scale: Returns: - + - output: output tensor + - permuted_scale: permuted scale if scale is not None """ assert row_id_map.shape[1] == num_experts output = torch.empty((num_out_tokens, hidden_size), diff --git a/linghe/utils/norm.py b/linghe/utils/norm.py index c55745c..6207cc6 100644 --- a/linghe/utils/norm.py +++ b/linghe/utils/norm.py @@ -288,11 +288,12 @@ def triton_rms_norm_and_block_quant_forward(x: torch.Tensor, 1: only output transposed tensor 2: return both Returns: - out: quantization data - scale: quantization scale - rms: Reciprocal of the root mean square of the input calculated over the last dimension. - transpose_output: quantization data of transposed gradient - transpose_scale: quantization scale of transposed gradient + - out: quantization data. + - scale: quantization scale. + - rms: Reciprocal of the root mean square of the + input calculated over the last dimension. + - transpose_output: quantization data of transposed gradient. + - transpose_scale: quantization scale of transposed gradient. 
""" # row-wise read, row-wise write M, N = x.shape @@ -581,9 +582,7 @@ def triton_rms_norm_and_smooth_quant_forward(x, weight, smooth_scale=None, calibrate=False, output_rms=False, round_scale=False): - """ - - """ + """""" M, N = x.shape assert N <= 8192 and 8192 % N == 0 device = x.device diff --git a/linghe/utils/rearange.py b/linghe/utils/rearange.py index 58e5d6c..868814c 100644 --- a/linghe/utils/rearange.py +++ b/linghe/utils/rearange.py @@ -43,8 +43,8 @@ def triton_split_and_cat(x, counts, indices, scales=None): scales: [bs] Returns: - y: output tensor - output_scales: output scales if scales is not None + - y: output tensor + - output_scales: output scales if scales is not None """ M, N = x.shape n_split = counts.shape[0] diff --git a/linghe/utils/rope.py b/linghe/utils/rope.py index 85f5696..882b014 100644 --- a/linghe/utils/rope.py +++ b/linghe/utils/rope.py @@ -89,8 +89,8 @@ def triton_half_rope_forward(q, k, freqs): freqs: rope freqs Returns: - qo: - ko: + - qo: query output + - ko: key output """ L, B, H, D = q.shape h = k.shape[2] @@ -340,9 +340,9 @@ def triton_qk_norm_and_half_rope_forward(qkv, q_norm_weight, k_norm_weight, transpose: whether qkv is tranposed, i.e., [S, B, dim], only support transpose format currently Returns: - qo: shape [B, S, H, head_dim] - ko: shape [B, S, h, head_dim] - vo: shape [B, S, h, head_dim] + - qo: shape [B, S, H, head_dim] + - ko: shape [B, S, h, head_dim] + - vo: shape [B, S, h, head_dim] """ assert transpose @@ -560,9 +560,9 @@ def triton_qk_norm_and_half_rope_backward(gq, gk, gv, qkv, q_norm_weight, interleave: Returns: - dqkv: gradient of qkv - dqw: gradient of q_norm_weight - dkw: gradient of k_norm_weight + - dqkv: gradient of qkv + - dqw: gradient of q_norm_weight + - dkw: gradient of k_norm_weight """ assert transpose B, L, H, D = gq.shape diff --git a/linghe/utils/scatter.py b/linghe/utils/scatter.py index bb39945..2a4e465 100644 --- a/linghe/utils/scatter.py +++ b/linghe/utils/scatter.py @@ -104,7 +104,7 @@ def triton_scatter_add(x, outputs, indices): indices: indices Returns: - outputs + output tensor """ M, N = x.shape @@ -186,8 +186,8 @@ def triton_unpermute_with_mask_map( probs: [num_out_tokens] Returns: - output: [num_tokens, hidden_size] - restore_probs: [num_tokens, num_experts] + - output: [num_tokens, hidden_size] + - restore_probs: [num_tokens, num_experts] """ hidden_size = grad.shape[1] num_tokens, num_experts = row_id_map.shape # not transposed diff --git a/linghe/utils/silu.py b/linghe/utils/silu.py index bfcf910..90545b3 100644 --- a/linghe/utils/silu.py +++ b/linghe/utils/silu.py @@ -122,8 +122,8 @@ def triton_weighted_silu_backward(g: torch.Tensor, weight: weight tensor Returns: - dx: gradient of x - dw: gradient of weight + - dx: gradient of x + - dw: gradient of weight """ # row-wise read, row-wise write M, N = x.shape @@ -228,10 +228,10 @@ def triton_silu_and_block_quant_forward(x, 2: output both Returns: - out: quantized tensor - scale: quantization scale - transpose_output: quantized tensor of transposed output - transpose_scale: quantization scale of transposed output + - out: quantized tensor + - scale: quantization scale + - transpose_output: quantized tensor of transposed output + - transpose_scale: quantization scale of transposed output """ M, N = x.shape n = N // 2 @@ -349,10 +349,10 @@ def triton_silu_and_block_quant_backward(g, x, round_scale: whether round to power of 2 Returns: - dx: quantized non-transposed gradient - dx_scale: scales of quantization non-transposed gradient - transpose_dx: quantized 
transposed gradient - transpose_dx_scale: scales of quantization transposed gradient + - dx: quantized non-transposed gradient + - dx_scale: scales of quantization non-transposed gradient + - transpose_dx: quantized transposed gradient + - transpose_dx_scale: scales of quantization transposed gradient """ M, N = x.shape n = N // 2 @@ -480,10 +480,10 @@ def triton_batch_weighted_silu_and_block_quant_forward(x, 2: output both Returns: - out: quantized tensor - scale: quantization scale - transpose_output: quantized tensor of transposed output - transpose_scale: quantization scale of transposed output + - out: quantized tensor + - scale: quantization scale + - transpose_output: quantized tensor of transposed output + - transpose_scale: quantization scale of transposed output """ M, N = x.shape n = N // 2 @@ -642,11 +642,11 @@ def triton_batch_weighted_silu_and_block_quant_backward(g, x, weight, splits: python int list of token count per expert round_scale: whether round scale to power of 2 Returns: - dx: quantized non-transposed gradient - dx_scale: scales of quantization non-transposed gradient - dw: gradient of weight - transpose_dx: quantized transposed gradient - transpose_dx_scale: scales of quantization transposed gradient + - dx: quantized non-transposed gradient + - dx_scale: scales of quantization non-transposed gradient + - dw: gradient of weight + - transpose_dx: quantized transposed gradient + - transpose_dx_scale: scales of quantization transposed gradient """ # row-wise read, row-wise write M, N = x.shape @@ -793,9 +793,7 @@ def compatible_silu_and_smooth_quant_forward_kernel(x_ptr, smooth_scale_ptr, out def triton_silu_and_smooth_quant_forward(x, smooth_scale=None, out=None, scale=None, maxs=None, round_scale=False, calibrate=False): - """ - - """ + """""" M, N = x.shape n = N // 2 device = x.device @@ -995,9 +993,7 @@ def triton_silu_and_smooth_quant_backward(g, x, transpose_smooth_scale=None, reverse=True, round_scale=False): - """ - - """ + """""" assert round_scale M, N = x.shape n = N // 2 @@ -1112,9 +1108,7 @@ def triton_batch_weighted_silu_and_smooth_quant_forward(x, round_scale=False, reverse=False, calibrate=False): - """ - - """ + """""" M, N = x.shape n = N // 2 n_experts = counts.shape[0] @@ -1365,9 +1359,7 @@ def triton_batch_weighted_silu_and_smooth_quant_backward(g, x, weight, splits=None, reverse=True, round_scale=False): - """ - - """ + """""" assert round_scale M, N = x.shape n = N // 2 diff --git a/tests/test_hadamard_quant.py b/tests/test_hadamard_quant.py index c4c6a57..f508e2f 100644 --- a/tests/test_hadamard_quant.py +++ b/tests/test_hadamard_quant.py @@ -12,7 +12,7 @@ torch_hadamard_transform, torch_row_quant, ) - +from linghe.facade.hadamard_quant_linear import HadamardQuantLinear @@ -73,6 +73,28 @@ def test_hadamard_quant(M=8192, N=1024, K=2048, B=64, bench=False): output_check(dyst, dyt_scale, 'dyt.scale') +def test_hadamard_quant_linear(M=8192, N=1024, K=2048, B=64): + + dtype = torch.bfloat16 + device = 'cuda:0' + linear = HadamardQuantLinear(K, N, bias=False, dtype=dtype, device=device) + x = torch.randn((M, K), dtype=dtype, device=device).requires_grad_() + w = torch.randn((N, K), dtype=dtype, device=device) + dy = torch.randn((M, N), dtype=dtype, device=device) + linear.weight.data.copy_(w) + + y_ref = x@w.t() + y = linear(x) + output_check(y_ref, y, mode='y') + + dx_ref = dy@w + dw_ref = dy.t()@x + y.backward(dy) + dw = linear.weight.grad + dx = x.grad + output_check(dx_ref, dx, mode='dx') + output_check(dw_ref, dw, mode='dw') if __name__ == 
'__main__': - test_hadamard_quant(M=8192, N=1024, K=2048, B=64, bench=False) \ No newline at end of file + test_hadamard_quant(M=8192, N=1024, K=2048, B=64, bench=False) + test_hadamard_quant_linear(M=8192, N=1024, K=2048, B=64) \ No newline at end of file diff --git a/tests/test_smooth_quant.py b/tests/test_smooth_quant.py index 4ddc166..b0011d9 100644 --- a/tests/test_smooth_quant.py +++ b/tests/test_smooth_quant.py @@ -15,7 +15,7 @@ torch_make_indices, torch_smooth_quant, round_up) - +from linghe.facade.smooth_quant_linear import SmoothQuantLinear def torch_split_smooth_quant(x_split, smooth_scales, round_scale=False): x_qs = [] @@ -294,6 +294,32 @@ def test_triton_batch_smooth_quant(M=4096, N=4096, n_experts=32, topk=8, n_repeat=n_repeat, ref_time=ref_time) + + +def test_smooth_quant_linear(M=8192, N=1024, K=2048): + + dtype = torch.bfloat16 + device = 'cuda:0' + linear = SmoothQuantLinear(K, N, bias=False, dtype=dtype, device=device) + x = (10*torch.randn((M, K), dtype=dtype, device=device)).requires_grad_() + w = 0.1*torch.randn((N, K), dtype=dtype, device=device) + dy = 1e-6*torch.randn((M, N), dtype=dtype, device=device) + linear.weight.data.copy_(w) + + y_ref = x@w.t() + y = linear(x) + output_check(y_ref, y, mode='y') + + dx_ref = dy@w + dw_ref = dy.t()@x + y.backward(dy) + dw = linear.weight.grad + dx = x.grad + output_check(dx_ref, dx, mode='dx') + output_check(dw_ref, dw, mode='dw') + + + if __name__ == '__main__': test_triton_smooth_quant(M=16384, N=2048, bench=False) test_triton_smooth_quant(M=8192, N=4096, bench=False) @@ -326,3 +352,4 @@ def test_triton_batch_smooth_quant(M=4096, N=4096, n_experts=32, topk=8, test_triton_batch_smooth_quant(M=4096, N=4096, n_experts=32, topk=8, round_scale=False) + test_smooth_quant_linear(M=8192, N=1024, K=2048) \ No newline at end of file