Skip to content

Commit

Permalink
docs: 初步添加 attention 算子定义
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Jan 24, 2024
1 parent 39538e7 commit 594e06b
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/08-01llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,39 @@ y = (x^2 + δ)^(-1/2) * w * x
1 Output:

- **Y(heterogeneous) - T**: 输出张量。形状与 `X` 相同。

## Attention

### Summary

Multi-head Self Attention 的封装形式,用于 transformer 模型。

支持使用 kv cache,使用条件由输入和属性综合决定。有以下 6 种情况:

| 序号 | 输入数量 | `max_seq_len` | 使用 kv cache | 输出数量 | cache s 维度 | 备注
|:-:|:-:|:-----:|:-------:|:-:|:------------------------:|:-
| 1 | 3 | 0 | none | 1 | - |
| 2 | 3 | S > 0 | init | 3 | `S` | `assert(S >= seq_len)`
| 3 | 4 | 0 | inplace | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量
| 4 | 4 | S > 0 | inplace | 3 | `S` | `assert(S >= past_seq_len + seq_len)`
| 5 | 6 | 0 | copy | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量
| 6 | 6 | S > 0 | copy | 3 | `S` | `assert(S >= past_seq_len + seq_len)`

### Attributes

- **max_seq_len - INT** (default is `0`): 最大序列长度,用于初始化 kv cache。

### Inputs

- **query(heterogeneous) - T**: 形状为 `N x n_head x seq_len x head_dim`
- **key(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`
- **value(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`
- **past_seq_len(optional) -int64**: 要连接的历史序列长度,必须为标量。不使用 kv cache 时留空。
- **k_cache(optional, heterogeneous) -T**: k 缓存的初始值,形状为 `N x n_kv_head x s x head_dim``s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。
- **v_cache(optional, heterogeneous) -T**: v 缓存的初始值,形状为 `N x n_kv_head x s x head_dim``s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。

### Outputs

- **output(heterogeneous) - T**: 形状与 `query` 相同。
- **k_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim``s` 的值根据 `Summary` 的描述计算。
- **v_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim``s` 的值根据 `Summary` 的描述计算。

0 comments on commit 594e06b

Please sign in to comment.