diff --git a/README.md b/README.md index ea37eb3..0f1e3fd 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,11 @@ pip install collie-lm git clone https://github.com/OpenLMLab/collie python setup.py install ``` +### MOSS 2 (alpha) +如需使用 MOSS 2 (alpha),请安装 `triton-nightly`。命令如下: +```bash +pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly +``` ## Docker安装 diff --git a/collie/models/moss2alpha/__init__.py b/collie/models/moss2alpha/__init__.py new file mode 100644 index 0000000..a286e10 --- /dev/null +++ b/collie/models/moss2alpha/__init__.py @@ -0,0 +1 @@ +from .model import MOSS2ForCausalLM \ No newline at end of file diff --git a/collie/models/moss2alpha/configuration_moss2.py b/collie/models/moss2alpha/configuration_moss2.py new file mode 100644 index 0000000..86050fa --- /dev/null +++ b/collie/models/moss2alpha/configuration_moss2.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MOSS2 model configuration""" +from typing import List, Optional, Tuple, Union +from transformers.configuration_utils import PretrainedConfig +from collie.log.logger import logger + + +MOSS2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class MOSS2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MOSS2Model`]. It is used to instantiate + an MOSS2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MOSS2-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MOSS2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MOSS2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + window_size_global (`int`, *optional*, defaults to `128`): + The global window size for the MOSS2 sparse attention. non positive values will disable the global attention. + window_size_left (`int`, *optional*, defaults to `1023`): + The left window size for the MOSS2 sparse attention. non positive values will disable sliding window, and + the attention mask will be causal regardless of the value of `window_size_global`. + `window_size_left` here has the same + meaning as `window_size[1]` in the original flash attention 2 implementation. For example, + if you want a sliding window of size 1024, you should set `window_size_left`=1023. + full_attention_layers (`List[int]`, *optional*, defaults to `None`): + The layers that will have full attention. If not specified, all layers will have sparse attention. + layer indices are 0-indexed. 
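To see how the window and full-attention options described above combine, here is a minimal illustrative sketch (not part of the diff; it only assumes the import path this change introduces):

```python
from collie.models.moss2alpha.configuration_moss2 import MOSS2Config

# Layers 0 and 31 keep full attention; every other layer uses the sparse pattern
# built from a 128-token global window plus a sliding window of 1024 tokens
# (window_size_left=1023 plus the query token itself).
config = MOSS2Config(
    num_hidden_layers=32,
    window_size_global=128,
    window_size_left=1023,
    full_attention_layers=[0, 31],
)

# __init__ packs the two window sizes into the triple handed to the sparse kernels.
assert config.window_size == (128, 1023, 0)
```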
+ Example: + + """ + model_type = "MOSS2" + _auto_class = "AutoConfig" + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + bias=True, + rope_theta=10000, + rope_scaling=None, + attn_implementation="eager", + window_size_global=128, + window_size_left=1023, + full_attention_layers: Optional[List]=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.bias = bias + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + self.attn_implementation = attn_implementation + if self.attn_implementation is None: + self.attn_implementation = "eager" + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.window_size = (window_size_global, window_size_left, 0) + self.full_attention_layers = full_attention_layers + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
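For reference, the shape of `rope_scaling` that the validation below accepts, shown as a small standalone sketch (not part of the diff):

```python
from collie.models.moss2alpha.configuration_moss2 import MOSS2Config

# Accepted: exactly the two keys "type" and "factor", type in {"linear", "dynamic"},
# factor a Python float >= 1.0.
MOSS2Config(rope_scaling={"type": "dynamic", "factor": 2.0})

# Rejected: the factor must be a float, so an int such as 2 raises ValueError.
try:
    MOSS2Config(rope_scaling={"type": "dynamic", "factor": 2})
except ValueError as err:
    print(err)
```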
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}") \ No newline at end of file diff --git a/collie/models/moss2alpha/model.py b/collie/models/moss2alpha/model.py new file mode 100644 index 0000000..293e37e --- /dev/null +++ b/collie/models/moss2alpha/model.py @@ -0,0 +1,1757 @@ +""" PyTorch MOSS2 model.""" +import math +import queue +import threading +import warnings +from typing import List, Optional, Tuple, Union +import gc +import json +import os +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torch.distributed as dist +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel, dtype_byte_size +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from megatron.core import parallel_state, tensor_parallel +from deepspeed.pipe import LayerSpec, TiedLayerSpec + +from collie.config import CollieConfig +from collie.driver.io import IODriver +from collie.log.logger import logger +from collie.module import ( + ColumnParallelLinearWithoutBias, + ColumnParallelLMHead, + RowParallelLinearWithoutBias, +) +from collie.utils import concat_tensor, dict_as_params, env, progress +from collie.models.base import CollieModelForCausalLM +from collie.models.utils import ( + kv_cache_to_inputs_for_layer, inputs_to_kv_cache_for_layer, + kv_cache_to_inputs_for_model, inputs_to_kv_cache_for_model, +) + +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + +from sparse_varlen_kernel import flash_attn_varlen_func as sparse_varlen_func +from sparse_kernel import flash_attn_func as sparse_func + +from .configuration_moss2 import MOSS2Config + +_CONFIG_FOR_DOC = "MOSS2Config" + +flash_attn_func, flash_attn_varlen_func = None, None +pad_input, index_first_axis, unpad_input = None, None, None +def _import_flash_attn(): + global flash_attn_func, flash_attn_varlen_func + global pad_input, index_first_axis, unpad_input + try: + from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + except ImportError: + raise ImportError("flash_attn is not installed.") + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def 
_get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MOSS2 +class MOSS2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MOSS2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->MOSS2 +class MOSS2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->MOSS2 +class MOSS2LinearScalingRotaryEmbedding(MOSS2RotaryEmbedding): + """MOSS2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->MOSS2 +class MOSS2DynamicNTKScalingRotaryEmbedding(MOSS2RotaryEmbedding): + """MOSS2RotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla. 
+ """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MOSS2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # modified for TP + self.w1 = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.w3 = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.w2 = RowParallelLinearWithoutBias( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + return down_proj + + +# Copied from transformers.model.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
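The equivalence claimed in the sentence above can be checked in isolation with toy shapes (plain PyTorch, not part of the diff):

```python
import torch

batch, num_key_value_heads, seqlen, head_dim, n_rep = 2, 4, 5, 8, 3
x = torch.randn(batch, num_key_value_heads, seqlen, head_dim)

# Same expand + reshape that repeat_kv performs.
y = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seqlen, head_dim)
y = y.reshape(batch, num_key_value_heads * n_rep, seqlen, head_dim)

assert torch.equal(y, torch.repeat_interleave(x, repeats=n_rep, dim=1))
```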
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Modified from transformers.model.llama.modeling_llama.LlamaAttention +class MOSS2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: CollieConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + self.window_size = config.model_config.window_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.wqkv = ColumnParallelLinearWithoutBias( + config.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + + self.wo = RowParallelLinearWithoutBias( + self.num_heads * self.head_dim, + config.hidden_size, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + self._init_rope() + if self.config.attn_implementation == "flash_attention_2" or self.config.use_flash: + _import_flash_attn() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = MOSS2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "dynamic": + self.rotary_emb = MOSS2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + elif scaling_type == "linear": + self.rotary_emb = MOSS2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + else: + raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.") + return self.rotary_emb + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
" + "Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if self.config.pp_size > 1: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads/self.config.tp_size, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads/self.config.tp_size, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads/self.config.tp_size, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads/self.config.tp_size, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, int(self.hidden_size/self.config.tp_size)) + + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class MOSS2FlashAttention2(MOSS2Attention): + """ + MOSS2 flash attention module. This module inherits from `MOSS2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MOSS2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. " + "Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if self.config.pp_size > 1: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len + ) + attn_output = attn_output.reshape(bsz, q_len, int(self.hidden_size/self.config.tp_size)).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
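To make the packed `wqkv` slicing used by the attention forward passes above easier to follow, here is a standalone toy version of the layout (a sketch, not part of the diff; tensor parallelism ignored):

```python
import torch
from einops import rearrange

bsz, q_len, head_dim = 1, 3, 4
num_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

# One fused projection produces queries, keys and values for every KV group.
qkv = torch.randn(bsz, q_len, (num_heads + 2 * num_key_value_heads) * head_dim)
qkv = rearrange(qkv, "b q (h gs d) -> b q h gs d",
                gs=2 + num_key_value_groups, d=head_dim)

q = rearrange(qkv[..., :num_key_value_groups, :], "b q h gs d -> b q (h gs) d")
k = qkv[..., -2, :]   # second-to-last slot of each group holds the key head
v = qkv[..., -1, :]   # last slot holds the value head

assert q.shape == (bsz, q_len, num_heads, head_dim)
assert k.shape == v.shape == (bsz, q_len, num_key_value_heads, head_dim)
```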
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, int(self.num_heads/self.config.tp_size), head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q.to(torch.int64), + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + +class MOSS2SparseAttention2(MOSS2FlashAttention2): + """ + MOSS2 sparse flash attention module. Only return attention output. Only support causal attention. + Attention bias, dropout and determistic backward are not supported. + This module inherits from `MOSS2FlashAttention2` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
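The unpadding bookkeeping that feeds the varlen kernels can be reproduced on a tiny example mirroring `_get_unpad_data` (standalone sketch, not part of the diff):

```python
import torch
import torch.nn.functional as F

# 1 marks real tokens, 0 marks (left) padding, as in the attention_mask argument.
attention_mask = torch.tensor([[0, 1, 1, 1],
                               [0, 0, 1, 1]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)            # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)                        # tensor([1, 2, 3, 6, 7]) - flat positions of real tokens
print(cu_seqlens)                     # tensor([0, 3, 5], dtype=torch.int32) - sequence boundaries
print(seqlens_in_batch.max().item())  # 3 - max_seqlen_in_batch
```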
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MOSS2SparseAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. " + "Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + "b q (h gs d) -> b q h gs d", + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d") + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if self.config.pp_size > 1: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len + ) + attn_output = attn_output.reshape(bsz, q_len, int(self.hidden_size/self.config.tp_size)).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout, not supported in sparse attention + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = sparse_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_in_batch_q, + max_seqlen_in_batch_k, + None, # bias + True, # casual + softmax_scale, + self.window_size, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = sparse_func( + query_states, key_states, value_states, + None, # bias + True, # casual + softmax_scale, + self.window_size, + ) + + return attn_output + +MOSS2_ATTENTION_CLASSES = { + "eager": MOSS2Attention, + "flash_attention_2": MOSS2FlashAttention2, + "sparse": MOSS2SparseAttention2, +} + +# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer +class MOSS2DecoderLayer(nn.Module): + def __init__(self, config: CollieConfig, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + + full_attn_list: Optional[List] = config.model_config.full_attention_layers + # self.attention = MOSS2_ATTENTION_CLASSES[config.attn_implementation](config=config) + if full_attn_list is not None \ + and layer_idx not in full_attn_list: + self.attention = MOSS2SparseAttention2(config=config) + else: + if config.attn_implementation == "flash_attention_2" or config.use_flash: + self.attention = MOSS2FlashAttention2(config=config) + else: + self.attention = MOSS2Attention(config=config) + + self.feed_forward = MOSS2MLP(config) + self.attention_norm = MOSS2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.ffn_norm = MOSS2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + # add for pp + self.idx = layer_idx + # 务必保持变量名一致 + self.use_cache = self.config.model_config.use_cache + self.hidden_states = None + self.output_attentions = False + + def _forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + # output_attentions: Optional[bool] = False, + # use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
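The per-layer attention choice made in `MOSS2DecoderLayer.__init__` above reduces to a simple rule; here is a plain-Python restatement under the option names used in this diff (a sketch, not part of the change):

```python
def attention_kind(layer_idx, full_attention_layers, use_flash):
    # use_flash stands in for `attn_implementation == "flash_attention_2" or config.use_flash`.
    if full_attention_layers is not None and layer_idx not in full_attention_layers:
        return "sparse"             # MOSS2SparseAttention2 (global + sliding window)
    if use_flash:
        return "flash_attention_2"  # MOSS2FlashAttention2 (full causal attention)
    return "eager"                  # MOSS2Attention

kinds = [attention_kind(i, full_attention_layers=[0, 31], use_flash=True) for i in range(32)]
assert kinds[0] == kinds[31] == "flash_attention_2" and kinds[1] == "sparse"
```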
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + + if position_ids is None: # for pp + seq_length = hidden_states.shape[1] + past_key_values_length = 0 + if past_key_value is not None: + past_key_values_length = past_key_value[0][0].shape[1] + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=self.output_attentions, + use_cache=self.use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, present_key_value + + def forward(self, inputs: dict): + layer_past = inputs_to_kv_cache_for_layer(idx=self.idx, inputs=inputs) + + if self.config.checkpointing and self.training: + hidden_states, new_layer_past = torch.utils.checkpoint.checkpoint( + self._forward, + inputs["hidden_states"], + inputs.get("attention_mask", None), + inputs.get("position_ids", None), + layer_past, + ) + else: + hidden_states, new_layer_past = self._forward( + inputs["hidden_states"], + inputs.get("attention_mask", None), + inputs.get("position_ids", None), + layer_past + ) # **inputs + inputs["hidden_states"] = hidden_states + + inputs.update(kv_cache_to_inputs_for_layer(idx=self.idx, new_layer_past=new_layer_past)) + return inputs + + +MOSS2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MOSS2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
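A hedged end-to-end sketch of wiring the configuration into the model; the checkpoint path is a placeholder, and the entry points follow the usual CoLLiE pattern (`CollieConfig.from_pretrained`, `from_pretrained` on the model class) rather than anything this diff guarantees:

```python
from collie import CollieConfig
from collie.models.moss2alpha import MOSS2ForCausalLM

config = CollieConfig.from_pretrained("path/to/moss2-alpha-checkpoint")  # placeholder path
config.model_config.attn_implementation = "flash_attention_2"
config.model_config.full_attention_layers = [0, 31]

# Distributed setup (dp/tp/pp sizes) is omitted here for brevity.
model = MOSS2ForCausalLM.from_pretrained("path/to/moss2-alpha-checkpoint", config=config)
```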
+""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->MOSS2 +@add_start_docstrings( + "The bare MOSS2 Model outputting raw hidden-states without any specific head on top.", + MOSS2_START_DOCSTRING, +) +class MOSS2PreTrainedModel(PreTrainedModel): + config_class = MOSS2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MOSS2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MOSS2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. 
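To illustrate the `past_key_values` contract described above, a minimal two-step decoding sketch; it assumes a `model` and `config` built as in the earlier sketch and `use_cache` left at its default of `True`:

```python
import torch

input_ids = torch.randint(0, config.vocab_size, (1, 8))

# Prefill: run the full prompt once and keep the per-layer key/value cache.
out = model(input_ids=input_ids)

# Decode step: feed only the newly sampled token together with the cache.
next_token = out.logits[:, -1:].argmax(dim=-1)
out = model(input_ids=next_token, past_key_values=out.past_key_values)
```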
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Modified from transformers.model.llama.modeling_llama.LlamaModel +@add_start_docstrings( + "The bare MOSS2 Model outputting raw hidden-states without any specific head on top.", + MOSS2_START_DOCSTRING, +) +class MOSS2Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MOSS2DecoderLayer`] + + Args: + config: CollieConfig + """ + + _auto_class = "AutoModel" + + def __init__(self, config: CollieConfig): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + # self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.tok_embeddings = tensor_parallel.VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.layers = nn.ModuleList([MOSS2DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + self.norm = MOSS2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + self.use_cache = self.config.model_config.use_cache + + def get_input_embeddings(self): + return self.tok_embeddings + + def set_input_embeddings(self, value): + self.tok_embeddings = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(MOSS2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + if self.config.attn_implementation == 
"flash_attention_2" or self.config.use_flash: + _import_flash_attn() + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + if self.config.attn_implementation == "flash_attention_2" or self.config.use_flash: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + use_cache = self.use_cache + if self.gradient_checkpointing and self.training: + if self.use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "hidden_states": hidden_states, + "position_ids": position_ids, + } + + inputs.update(kv_cache_to_inputs_for_model(past_key_values)) + + all_hidden_states = () + for layer in self.layers: + all_hidden_states += (inputs["hidden_states"],) + inputs.update(layer(inputs)) + inputs["hidden_states"] = self.norm(inputs["hidden_states"]) + all_hidden_states += (inputs["hidden_states"],) + + past_key_values = inputs_to_kv_cache_for_model(self.config.num_hidden_layers, inputs) + + return BaseModelOutputWithPast( + last_hidden_state=inputs["hidden_states"], + hidden_states=all_hidden_states, + past_key_values=past_key_values, + ) + + @classmethod + def pipeline_layers(cls, config: CollieConfig): + """ + Get layers of pipeline. 
+ + :return: list + """ + if isinstance(config, str): + config = CollieConfig.from_pretrained(config) + + if config.tie_word_embeddings: + embed_tokens = TiedLayerSpec( + "embed_tokens", + dict_as_params(input_keys="input_ids", output_keys="hidden_states"), + tensor_parallel.VocabParallelEmbedding, + config.vocab_size, + config.hidden_size, + ) + else: + embed_tokens = LayerSpec( + dict_as_params(input_keys="input_ids", output_keys="hidden_states"), + tensor_parallel.VocabParallelEmbedding, + config.vocab_size, + config.hidden_size, + ) + + layers = [ + LayerSpec(MOSS2DecoderLayer, config, i) for i in range(config.num_hidden_layers) + ] + norm = LayerSpec( + dict_as_params(input_keys="hidden_states", output_keys="hidden_states"), + MOSS2RMSNorm, + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + ) + + return [ + ("tok_embeddings", embed_tokens), + ("layers", layers), + ("norm", norm), + ] + +# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM +class MOSS2ForCausalLM(CollieModelForCausalLM): + _auto_class = "AutoModelForCausalLM" + + _tied_weights_keys = ["output.weight"] + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.model = MOSS2Model(config) + self.vocab_size = config.vocab_size + # self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.output = ColumnParallelLinearWithoutBias( + self.collie_config.hidden_size, self.collie_config.vocab_size, bias=False + ) + # Initialize weights and apply final processing + # self.post_init() + # GenerationMixin 需要的额外参数 + self.config.is_decoder = True + if config.model_config.tie_word_embeddings: + self.lm_head.weight = self.embed_tokens.weight + self.main_input_name = "input_ids" + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + self.output = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ): + output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + logits = self.output(output.last_hidden_state) + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=output.past_key_values, + hidden_states=output.hidden_states, + attentions=None, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + 
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""): + if tokenizer.add_bos_token: + prompt = "" + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" + for record in history: + prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" + prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" + return tokenizer([prompt], return_tensors="pt") + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: str = "You are an AI assistant whose name is MOSS2. You can chat with me.", + **kwargs, + ): + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split("<|im_end|>")[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + "The version of `transformers` is too low. Please make sure " + "that you have installed `transformers>=4.28.0`." 
+ ) + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = "" + self.cache = [] + self.received_inputs = False + self.queue.put((self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("ChatStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode(self.cache, skip_special_tokens=True) + if token.strip() != "<|im_end|>": + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + def clean_cache(self): + self._clean_hidden_states([*self.model.layers, self.output]) + self._set_use_cache(self.model.layers, False) + + def set_cache(self, use_cache): + self._set_use_cache(self.model.layers, use_cache) + + @classmethod + def pipeline_layers(cls, config: CollieConfig): + """ + Get layers of pipeline. + + :return: list + """ + if isinstance(config, str): + config = CollieConfig.from_pretrained(config) + + if config.tie_word_embeddings: + output = TiedLayerSpec( + "embed_tokens", + dict_as_params(input_keys="hidden_states", output_keys="logits"), + ColumnParallelLMHead, + config.hidden_size, + config.vocab_size, + bias=False, + ) + else: + output = LayerSpec( + dict_as_params(input_keys="hidden_states", output_keys="logits"), + ColumnParallelLMHead, + config.hidden_size, + config.vocab_size, + bias=False, + ) + + return [("model", MOSS2Model.pipeline_layers(config)), ("output", output)] + + @staticmethod + def load_parallel_state_dict( + path: str, + config: Union[CollieConfig, str], + process_exclusion: bool = False, + **kwargs, + ): + ... + + @staticmethod + def load_parallel_state_dict( + path: str, + config: Union[CollieConfig, str], + process_exclusion: bool = False, + protocol: str = "file", + **kwargs, + ): + """ + Load state_dict from ``path``. + + The format of pretrained model should be the same as that of + `huggingface`. + + :return: state_dict. Note that the state_dict should be processed + properly to match the current rank. 
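+
+        Column-parallel weights (wqkv, w1, w3, tok_embeddings, output) are
+        split along dim 0 and row-parallel weights (wo, w2) along dim 1,
+        according to the tensor parallel rank of the current process.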
+        """
+        if isinstance(config, str):
+            config = CollieConfig.from_pretrained(config)
+        io_driver = IODriver.from_protocol(protocol)
+        if not io_driver.exists(path):
+            raise FileNotFoundError(f"folder {path} not found.")
+        state_dict = OrderedDict()
+        weights = []
+        parts = None
+        # With process exclusion, every process shows its own progress bar; otherwise only RANK 0 does
+        hide_progress = not process_exclusion and int(os.environ.get("RANK", "0")) != 0
+        if dist.is_initialized() and process_exclusion:
+            # With process exclusion enabled, loop dist.get_world_size() times
+            rank_order = range(dist.get_world_size())
+        else:
+            # Without it, a single pass is enough
+            rank_order = range(1)
+        for rank in rank_order:
+            # With process exclusion, only the matching RANK enters this iteration; without it, every process does
+            if int(os.environ.get("RANK", "0")) == rank or not process_exclusion:
+                # The pipeline partition is stored in os.environ["COLLIE_PP_PARTS"], e.g. [0, 17, 35], as left-closed, right-open intervals
+                if env.is_pipeline:
+                    # stored in json format
+                    parts = env.pipeline_parts
+                if hasattr(config, "num_key_value_heads"):
+                    # llama2 (transformers >= 4.31.0)
+                    num_key_value_heads = config.num_key_value_heads
+                else:
+                    num_key_value_heads = config.num_attention_heads
+                head_dim = config.hidden_size // config.num_attention_heads
+                # If pytorch_model.bin.index.json exists, each pp process can load only the weights it needs
+                if (
+                    io_driver.exists(os.path.join(path, "pytorch_model.bin.index.json"))
+                    and "COLLIE_PP_PARTS" in os.environ.keys()
+                ):
+                    weight_map = json.loads(
+                        io_driver.load(
+                            os.path.join(path, "pytorch_model.bin.index.json"), mode="r"
+                        )
+                    )["weight_map"]
+                    # layers holds the layer indices this rank needs
+                    layers = env.pipeline_layers_idx
+                    # Select keys shaped like model.layers.0. Two conditions: 1. the key contains a layer number;
+                    # 2. that number plus one is in layers (the embedding occupies the first layer)
+                    weights.extend(
+                        [
+                            value
+                            for key, value in weight_map.items()
+                            if len(key.split(".")) > 2
+                            and key.split(".")[2].isdigit()
+                            and (int(key.split(".")[2]) + 1) in layers
+                        ]
+                    )
+                    # deduplicate
+                    weights = list(set(weights))
+                    # Further filtering: layer 0 needs the embedding, the last layer needs the lm_head,
+                    # and the second-to-last layer needs the norm
+                    if 0 in layers:
+                        weights.append(weight_map["model.tok_embeddings.weight"])
+                    if max(parts) - 1 in layers:
+                        weights.append(weight_map["output.weight"])
+                    if max(parts) - 2 in layers:
+                        weights.append(weight_map["model.norm.weight"])
+                else:
+                    # Without pytorch_model.bin.index.json, load all weights
+                    # Prefer weights stored as safetensors
+                    weights = [
+                        weight
+                        for weight in io_driver.list(path)
+                        if weight.endswith(".safetensors")
+                    ]
+                    if len(weights) == 0:
+                        # Fall back to .bin files if there are no safetensors files
+                        weights = [
+                            weight
+                            for weight in io_driver.list(path)
+                            if weight.endswith(".bin")
+                        ]
+                with progress(
+                    weights,
+                    desc="Loading state dict",
+                    total=len(weights),
+                    disable=hide_progress,
+                ) as pbar:
+                    for weight in pbar:
+                        part_state_dict = io_driver.load(
+                            os.path.join(path, weight), mode="rb"
+                        )
+                        state_dict.update(part_state_dict)
+                        del part_state_dict
+                if parts is not None:
+                    # Second filtering pass for pipeline parallelism
+                    layers = env.pipeline_layers_idx
+                    for key in list(state_dict.keys()):
+                        if key.startswith("layers"):
+                            layer = int(key.split(".")[1])
+                            if layer + 1 not in layers:
+                                state_dict.pop(key)
+                        if key.endswith("tok_embeddings.weight"):
+                            if 0 not in layers:
+                                state_dict.pop(key)
+                        if key == "norm.weight":
+                            if max(parts) - 2 not in layers:
+                                state_dict.pop(key)
+                        if key.endswith("output.weight"):
+                            if max(parts) - 1 not in layers:
+                                state_dict.pop(key)
+                # Split according to the newly configured tp size
+                for key in list(state_dict.keys()):
+                    col_filter = [
+                        "wqkv.weight",
+                        "w1.weight",
+                        "w3.weight",
+                        "tok_embeddings.weight",
+                        "output.weight",
+                    ]
+                    col_split = any([key.endswith(filter) for filter in col_filter])
+
+                    if col_split:
+                        tensor = (
+                            
list(torch.chunk(state_dict[key], config.tp_size, dim=0))[ + env.tp_rank + ] + .detach() + .clone() + ) + del state_dict[key] + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + state_dict[key] = tensor + elif key.endswith("wo.weight") or key.endswith("w2.weight"): + tensor = ( + list(torch.chunk(state_dict[key], config.tp_size, dim=1))[ + env.tp_rank + ] + .detach() + .clone() + ) + del state_dict[key] + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + state_dict[key] = tensor + if dist.is_initialized() and process_exclusion: + # 如果选择了进程互斥,那么本次循环中不需要加载权重的进程需等待 + dist.barrier() + return state_dict + + @staticmethod + def save_parallel_state_dict( + state_dict: dict, + path: str, + config: CollieConfig, + process_exclusion: bool = False, + **kwargs, + ): + ... + + @staticmethod + def save_parallel_state_dict( + state_dict: dict, + path: str, + config: CollieConfig, + process_exclusion: bool = False, + protocol: str = "file", + ): + """ + Save state_dict to ``path``. + + The format of saved state dict should be the same as that of + `huggingface`. + """ + io_driver = IODriver.from_protocol(protocol) + # gather to tp rank 0 + if dist.is_initialized() and process_exclusion: + # 如果启动了进程互斥,则要进行 pp_size 次循环 + rank_order = range(config.pp_size) + else: + # 不开启只进行一次循环 + rank_order = range(1) + dst = parallel_state.get_tensor_model_parallel_src_rank() + with progress( + rank_order, + desc="Saving model", + disable=int(os.environ.get("RANK", "0")) != 0, + ) as pbar: + for rank in pbar: + if env.dp_rank == 0 and (env.pp_rank == rank or not process_exclusion): + for key in sorted(list(state_dict.keys())): + tensor_list = None + if env.tp_rank == 0: + tensor_list = [ + torch.zeros_like(state_dict[key]) + .to(state_dict[key].dtype) + .cuda() + for _ in range(config.tp_size) + ] + dist.gather( + state_dict[key].cuda(), + dst=dst, + gather_list=tensor_list, + group=env.tp_group, + ) + if env.tp_rank == 0: + col_filter = [ + "wqkv.weight", + "w1.weight", + "w3.weight", + "tok_embeddings.weight", + "output.weight", + ] + col_split = any( + [key.endswith(filter) for filter in col_filter] + ) + + if col_split: + state_dict[key] = concat_tensor(tensor_list, dim=0) + + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + + elif key.endswith("wo.weight") or key.endswith("w2.weight"): + state_dict[key] = concat_tensor(tensor_list, dim=1) + + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + + if env.tp_rank == 0: + # Save gathered weights + if env.is_pipeline: + ckpt_name = f"pytorch_model-{env.pp_rank + 1:05d}-of-{config.pp_size:05d}.bin" + total_size = 0 + weight_map = {} + for name, weight in state_dict.items(): + weight_size = weight.numel() * dtype_byte_size( + weight.dtype + ) + weight_map[name] = ckpt_name + total_size += weight_size + index_dict = dict( + total_size=total_size, weight_map=weight_map + ) + index_dicts = [None for _ in range(env.pp_size)] + dist.gather_object( + index_dict, index_dicts if env.pp_rank == 0 else None, group=env.pp_group + ) + if env.pp_rank == 0: + total_size = 0 + weight_map = {} + for _index_dict in index_dicts: + total_size += _index_dict["total_size"] + weight_map.update(_index_dict["weight_map"]) + merged_dict = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map, + } + io_driver.save( + json.dumps(merged_dict, indent=2, sort_keys=True) + + "\n", + os.path.join(path, "pytorch_model.bin.index.json"), + ) + + else: + ckpt_name = f"pytorch_model.bin" + ckpt_path = os.path.join(path, ckpt_name) + io_driver.save(state_dict, 
ckpt_path) + if dist.is_initialized() and process_exclusion: + dist.barrier() + if env.rank == 0: + config.save_pretrained(path, protocol=protocol) + dist.barrier() \ No newline at end of file diff --git a/collie/models/moss2alpha/sparse_kernel.py b/collie/models/moss2alpha/sparse_kernel.py new file mode 100644 index 0000000..89836a6 --- /dev/null +++ b/collie/models/moss2alpha/sparse_kernel.py @@ -0,0 +1,809 @@ +import math + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_one_row_block( + start_m, + Q: tl.const, + K: tl.const, + V: tl.const, + Out, + Lse, + softmax_scale: tl.constexpr, + stride_qm: tl.constexpr, + stride_kn: tl.constexpr, + stride_vn: tl.constexpr, + stride_om: tl.constexpr, + actual_seqlen_q, + actual_seqlen_k, + window_size_global, + SEQOFFSET, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + q_ptrs = (Q + (offs_m[:, None] * stride_qm + offs_d[None, :])) + l_ptrs = Lse + offs_m + out_ptrs = (Out + (offs_m[:, None] * stride_om + offs_d[None, :])) + if EVEN_M: + q = tl.load(q_ptrs, cache_modifier=".cg") + else: + q = tl.load(q_ptrs, mask=offs_m[:, None] < actual_seqlen_q, other=0.0, cache_modifier=".cg") + + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = softmax_scale * log2e + # load k, v + offs_n_base = tl.arange(0, BLOCK_N) + k_ptrs = (K + (offs_d[:, None] + offs_n_base[None, :] * stride_kn)) # (BLOCK_HEADDIM, BLOCK_N) + v_ptrs = (V + (offs_n_base[:, None] * stride_vn + offs_d[None, :])) + global_end_n = tl.cdiv(window_size_global, BLOCK_N) * BLOCK_N + global_end_n = tl.multiple_of(global_end_n, BLOCK_N) + # loop of global part(could be 0) + for start_n in range(0, global_end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_base + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kn, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vn, cache_modifier=".cg") + else: + mask_n = offs_n < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kn, mask=mask_n[None, :], other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vn, mask=mask_n[:, None], other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + # no need to mask for EVEN_N in `global` case + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m[:, None] >= offs_n[None, :]) & + (((SEQOFFSET + offs_m)[:, None] <= offs_n[None, :]) | ((offs_n < window_size_global)[None, :])) + , 0, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) + p = tl.math.exp2(qk * qk_scale - m_i_new[:, None] * qk_scale) + m_i = m_i_new + + # -- scale and update acc: acc *= alpha[:, None]-- + acc *= alpha[:, None] + acc = tl.dot(p.to(q.dtype), v, acc) + + # -- update m_i and l_i -- + l_i = tl.fma(l_i, alpha, tl.sum(p, 1)) + + local_start_n = tl.maximum(((start_m * BLOCK_M + SEQOFFSET) // BLOCK_N) * BLOCK_N, global_end_n) + local_start_n = 
tl.multiple_of(local_start_n, BLOCK_N) + end_n = tl.minimum((start_m + 1) * BLOCK_M, actual_seqlen_k) + for start_n in range(local_start_n, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_base + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kn, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vn, cache_modifier=".cg") + else: + mask_n = offs_n < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kn, mask=mask_n[None, :], other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vn, mask=mask_n[:, None], other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where(offs_n[None, :] < actual_seqlen_k, 0, float("-inf")) + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m[:, None] >= offs_n[None, :]) & ### $ + (((SEQOFFSET + offs_m)[:, None] <= offs_n[None, :])) # `local` part so we need not to (start_n + offs_n < window_size_global) + , 0, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) + p = tl.math.exp2(qk * qk_scale - m_i_new[:, None] * qk_scale) + # -- scale and update acc: acc *= alpha[:, None]-- + acc *= alpha[:, None] + acc = tl.dot(p.to(q.dtype), v, acc, input_precision="tf32") + + # -- update m_i and l_i -- + l_i = tl.fma(l_i, alpha, tl.sum(p, 1)) + m_i = m_i_new + + acc = acc * (1.0 / l_i[:, None]) # reduce the number of division + l = tl.fma(m_i, softmax_scale, tl.log(l_i)) # log(normalizer) + # initialize pointers to output + if EVEN_M: + tl.store(l_ptrs, l, cache_modifier=".cs") # .cs is for data accessed once + tl.store(out_ptrs, acc, cache_modifier=".cs") + else: + mask_m = offs_m < actual_seqlen_q + tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cs") + tl.store(out_ptrs, acc, mask=mask_m[:, None], cache_modifier=".cs") + +@triton.heuristics( + { + "BLOCK_M": lambda args: 128, + "BLOCK_N": lambda args: 128, # 64 or 128 + "EVEN_M": lambda args: args["actual_seqlen_q"] & 127 == 0, # % 128 == 0 + "EVEN_N": lambda args: args["actual_seqlen_k"] & 127 == 0, # % 128 == 0 + "num_warps": lambda args: 8, + "num_stages": lambda args: 3, # for faster forward pass + } +) +@triton.jit +def _fwd_kernel( + Q: tl.const, + K: tl.const, + V: tl.const, + Out, + Lse, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_qh: tl.constexpr, + stride_qm: tl.constexpr, + stride_kb: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_vb: tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_ob: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + actual_seqlen_q: tl.constexpr, + actual_seqlen_k: tl.constexpr, + max_seqlen_q_rounded, + nheads: tl.constexpr, + nheads_k: tl.constexpr, + window_size_global: tl.constexpr, + SEQOFFSET: tl.constexpr, + d_rounded: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_b = tl.program_id(2) + # invalid grid block for this seq + if actual_seqlen_q <= start_m * BLOCK_M: + return + + off_h = tl.program_id(1) + Q += ( off_b * stride_qb + off_h * stride_qh ) + Out += ( off_b * stride_ob + off_h * stride_oh ) + Lse += (off_b * nheads + off_h) * max_seqlen_q_rounded + + off_h_kv = 
off_h * nheads_k // nheads + K += ( off_b * stride_kb + off_h_kv * stride_kh ) + V += ( off_b * stride_vb + off_h_kv * stride_vh ) + + _fwd_kernel_one_row_block( + start_m, + Q, + K, + V, + Out, + Lse, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_om, + actual_seqlen_q, + actual_seqlen_k, + window_size_global, + SEQOFFSET, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_HEADDIM=d_rounded, + ) + +@triton.heuristics( + { + "BLOCK_M": lambda args: 128, + "EVEN_M": lambda args: args["actual_seqlen_q"] & 127 == 0, # % 128 == 0 + # "num_warps": lambda args: 8, + # "num_stages": lambda args: 1, + } +) +@triton.jit +def _bwd_preprocess_do_o_dot( + Out: tl.const, + DO: tl.const, + Delta, + stride_ob: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + stride_dob: tl.constexpr, + stride_doh: tl.constexpr, + stride_dom: tl.constexpr, + actual_seqlen_q: tl.constexpr, + max_seqlen_q_rounded, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + EVEN_M: tl.constexpr, +): + start_m = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + if actual_seqlen_q <= start_m * BLOCK_M: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o_ptrs = (Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]) + do_ptrs = ( + DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :] + ) + EVEN_M = actual_seqlen_q % BLOCK_M == 0 + + if EVEN_M: + o = tl.load(o_ptrs, cache_modifier=".cg").to(tl.float32) + do = tl.load(do_ptrs, cache_modifier=".cg").to(tl.float32) + else: + mask_m = (offs_m < actual_seqlen_q)[:, None] + o = tl.load(o_ptrs, mask=mask_m, other=0.0, cache_modifier=".cg").to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m, other=0.0, cache_modifier=".cg").to(tl.float32) + + delta = tl.sum(o * do, axis=1) + # off_b * tl.num_programs(1) + off_h == off_b * nheads + off_h + delta_ptrs = Delta + (off_b * tl.num_programs(1) + off_h) * max_seqlen_q_rounded + offs_m + # write-back + if EVEN_M: + tl.store(delta_ptrs, delta, cache_modifier=".cs") + else: + tl.store(delta_ptrs, delta, mask=(offs_m < actual_seqlen_q), cache_modifier=".cs") + +@triton.heuristics( + { + "BLOCK_M": lambda args: 64, + "BLOCK_N": lambda args: 64, # Reducing block sizes + "EVEN_M": lambda args: args["actual_seqlen_q"] & 63 == 0, # % 64 == 0 + "EVEN_N": lambda args: args["actual_seqlen_k"] & 63 == 0, # % 64 == 0 + "num_warps": lambda args: 4, + "num_stages": lambda args: 2, + } +) +@triton.jit +def _bwd_dk_dv_kernel( + Q: tl.const, + K: tl.const, + V: tl.const, + DO: tl.const, + DK, + DV, + LSE: tl.const, + D: tl.const, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_qh: tl.constexpr, + stride_qm: tl.constexpr, + stride_kb: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_vb: tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_dob: tl.constexpr, + stride_doh: tl.constexpr, + stride_dom: tl.constexpr, + stride_dkb: tl.constexpr, + stride_dkh: tl.constexpr, + stride_dkn: tl.constexpr, + stride_dvb: tl.constexpr, + stride_dvh: tl.constexpr, + stride_dvn: tl.constexpr, + nheads: tl.constexpr, + nheads_k: tl.constexpr, + actual_seqlen_q: tl.const, + actual_seqlen_k: tl.const, + max_seqlen_q_rounded, + window_size_global: tl.constexpr, + SEQOFFSET: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: 
tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, +): + off_h = tl.program_id(1) + start_n = tl.program_id(0) + if actual_seqlen_k <= start_n * BLOCK_N: + return + + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = softmax_scale * log2e + + off_b = tl.program_id(2) + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + off_h_kv = off_h * nheads_k // nheads + K += off_b * stride_kb + off_h_kv * stride_kh + V += off_b * stride_vb + off_h_kv * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + # pointer to row-wise quantities in value-like data + off_hb = off_b * nheads + off_h + D += off_hb * max_seqlen_q_rounded + LSE += off_hb * max_seqlen_q_rounded + + begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n_slice = offs_n[None, :] + offs_m_base = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # k & v transposed here + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n_slice * stride_kn + offs_d[:, None]) # transposed here + v_ptrs = V + (offs_n_slice * stride_vn + offs_d[:, None]) # transposed here + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + + # load k , v + if EVEN_N: + k = tl.load(k_ptrs, cache_modifier=".cg") + v = tl.load(v_ptrs, cache_modifier=".cg") + else: + mask_n = offs_n_slice < actual_seqlen_k + k = tl.load(k_ptrs, mask=mask_n, other=0.0, cache_modifier=".cg") + v = tl.load(v_ptrs, mask=mask_n, other=0.0, cache_modifier=".cg") + + # loop over rows + num_block_m = tl.cdiv(actual_seqlen_q, BLOCK_M) + end_m = num_block_m * BLOCK_M # $ + global_end_n = tl.cdiv(window_size_global, BLOCK_N) - 1 + local_start_n = tl.cdiv(actual_seqlen_q + SEQOFFSET, BLOCK_N) - 1 # actual_seqlen_k-window_size_left + if (start_n > global_end_n) & (start_n < local_start_n): + end_m = tl.cdiv((start_n + 1) * BLOCK_N - SEQOFFSET, BLOCK_M) * BLOCK_M # $ + for start_m in range(begin_m, end_m, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m = start_m + offs_m_base + # load q, l, do on-chip + if EVEN_M: + q = tl.load(q_ptrs, cache_modifier=".cg") + do = tl.load(do_ptrs, cache_modifier=".cg") + l = tl.load(LSE + offs_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, cache_modifier=".cg") + else: + mask_m = offs_m < actual_seqlen_q + q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0, cache_modifier=".cg") + do = tl.load(do_ptrs, mask=mask_m[:, None], other=0.0, cache_modifier=".cg") + l = tl.load(LSE + offs_m, mask=mask_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, mask=mask_m, cache_modifier=".cg") + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + + # recompute p = softmax(qk, dim=-1).T + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where(offs_n_slice < actual_seqlen_k, 0, float("-inf")) + qk += tl.where( # True will not mask + (offs_m[:, None] 
>= offs_n_slice) & + (((SEQOFFSET + offs_m)[:, None] <= offs_n_slice) | (offs_n_slice < window_size_global)) + , 0, float("-inf")) + + p = tl.math.exp2(qk * qk_scale - l[:, None] * log2e) + dv = tl.dot(tl.trans(p.to(do.dtype)), do, dv, input_precision="tf32") + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do, v, dp, input_precision="tf32") + + ds = (p * (dp - Di[:, None])).to(q.dtype) + # compute dk = dot(ds.T, q) + dk = tl.dot(tl.trans(ds), q, dk, input_precision="tf32") + + dk *= softmax_scale + # write-back + if EVEN_N: + tl.store(dv_ptrs, dv.to(k.dtype), cache_modifier=".cs") + tl.store(dk_ptrs, dk.to(k.dtype), cache_modifier=".cs") + else: + mask_n = offs_n < actual_seqlen_k + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None], cache_modifier=".cs") + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None], cache_modifier=".cs") + + +@triton.heuristics( + { + "BLOCK_M": lambda args: 64, + "BLOCK_N": lambda args: 64, # Reducing block sizes + "EVEN_M": lambda args: args["actual_seqlen_q"] & 63 == 0, # % 64 == 0 + "EVEN_N": lambda args: args["actual_seqlen_k"] & 63 == 0, # % 64 == 0 + "num_warps": lambda args: 4, + "num_stages": lambda args: 2, + } +) +@triton.jit +def _bwd_dq_kernel( + Q: tl.const, + K: tl.const, + V: tl.const, + DO: tl.const, + DQ, + LSE: tl.const, + D: tl.const, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_qh: tl.constexpr, + stride_qm: tl.constexpr, + stride_kb: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_vb: tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_dob: tl.constexpr, + stride_doh: tl.constexpr, + stride_dom: tl.constexpr, + stride_dqb: tl.constexpr, + stride_dqh: tl.constexpr, + stride_dqm: tl.constexpr, + nheads: tl.constexpr, + nheads_k: tl.constexpr, + actual_seqlen_q: tl.const, + actual_seqlen_k: tl.const, + max_seqlen_q_rounded, + window_size_global: tl.constexpr, + SEQOFFSET: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, +): + off_h = tl.program_id(1) + off_b = tl.program_id(2) + start_m = tl.program_id(0) + # invalid grid block for this seq + if actual_seqlen_q <= start_m * BLOCK_N: + return + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = softmax_scale * log2e + + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + off_h_kv = off_h * nheads_k // nheads + K += off_b * stride_kb + off_h_kv * stride_kh + V += off_b * stride_vb + off_h_kv * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + # pointer to row-wise quantities in value-like data + off_hb = off_b * nheads + off_h + D += off_hb * max_seqlen_q_rounded + LSE += off_hb * max_seqlen_q_rounded + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_m_slice = offs_m[:, None] + q_ptrs = (Q + (offs_m_slice * stride_qm + offs_d[None, :])) + do_ptrs = (DO + (offs_m_slice * stride_dom + offs_d[None, :])) + dq_ptrs = DQ + (offs_m_slice * stride_dqm + offs_d[None, :]) + + # load q & do: it will stay in SRAM throughout + # load q & do & load l & delta: it will stay in SRAM throughout + if EVEN_M: + q = tl.load(q_ptrs, cache_modifier=".cg") + do = tl.load(do_ptrs, cache_modifier=".cg") + l = tl.load(LSE + offs_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, cache_modifier=".cg") + else: + mask_m = offs_m < actual_seqlen_q + l = tl.load(LSE + 
offs_m, mask=mask_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, mask=mask_m, cache_modifier=".cg") + q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg") + do = tl.load(do_ptrs, mask=mask_m[:, None], cache_modifier=".cg") + l_slice_scale = l[:, None] * log2e + + # dq init + dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # load k, v. k, v transposed here + offs_n_base_slice = (tl.arange(0, BLOCK_N))[None, :] + k_ptrs = (K + (offs_d[:, None] + offs_n_base_slice * stride_kn)) # (BLOCK_HEADDIM, BLOCK_N) + v_ptrs = (V + (offs_d[:, None] + offs_n_base_slice * stride_vn)) + global_end_n = tl.cdiv(window_size_global, BLOCK_N) * BLOCK_N + global_end_n = tl.multiple_of(global_end_n, BLOCK_N) + # loop of global part(could be 0) + for start_n in range(0, global_end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n_slice = start_n + offs_n_base_slice + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kn, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vn, cache_modifier=".cg") + else: + mask_n = offs_n_slice < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kn, mask=mask_n, other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vn, mask=mask_n, other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + # no need to mask for EVEN_N in `global` case + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m_slice >= offs_n_slice) & + ((SEQOFFSET + offs_m_slice <= offs_n_slice) | (offs_n_slice < window_size_global)) + , 0, float("-inf")) + + # -- compute p --- + p = tl.math.exp2(qk * qk_scale - l_slice_scale) + # compute dq = dot(p, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do.to(q.dtype), v, dp, input_precision="tf32") + + ds = (p * (dp - Di[:, None])).to(q.dtype) + dq = tl.dot(ds, tl.trans(k), dq) + + local_start_n = tl.maximum(((start_m * BLOCK_M + SEQOFFSET) // BLOCK_N) * BLOCK_N, global_end_n) + local_start_n = tl.multiple_of(local_start_n, BLOCK_N) + end_n = tl.minimum((start_m + 1) * BLOCK_M, actual_seqlen_k) + for start_n in range(local_start_n, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n_slice = start_n + offs_n_base_slice + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kn, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vn, cache_modifier=".cg") + else: + mask_n = offs_n_slice < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kn, mask=mask_n, other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vn, mask=mask_n, other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where(offs_n_slice < actual_seqlen_k, 0, float("-inf")) + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m_slice >= offs_n_slice) & ### $ + ((SEQOFFSET + offs_m_slice <= offs_n_slice)) # `local` part so we need not to (start_n + offs_n < window_size_global) + , 0, float("-inf")) + + # -- compute p --- + p = tl.math.exp2(qk * qk_scale - l_slice_scale) + # compute dq = dot(p, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do.to(q.dtype), v, dp, input_precision="tf32") + + ds = (p * (dp - Di[:, None])).to(q.dtype) + dq = 
tl.dot(ds, tl.trans(k), dq) + + dq *= softmax_scale + if EVEN_M: + tl.store(dq_ptrs, dq.to(q.dtype), cache_modifier=".cs") + else: + tl.store(dq_ptrs, dq.to(q.dtype), mask=offs_m_slice < actual_seqlen_q, cache_modifier=".cs") + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None, window_size=(-1,-1,-1), + ): + """ + q: ((b s), nheads, headdim) + k, v: ((b s) nheads, headdim) + bias: deleted. + """ + # Make sure that the last dimension is contiguous + batch, seqlen_q, num_heads, d = q.shape + _, seqlen_k, num_heads_k, dk = k.shape + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1 + assert num_heads % num_heads_k == 0, "num_heads must be divisible by num_heads_k" + assert d == dk and dk == v.shape[-1] and num_heads_k == v.shape[-2], "num_heads and head dimensions must match" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda, "All tensors must be sent to GPU" + # I choose not to support the case where `d` is not a power of 2, + # which is sufficient for current LLM and simplifies the mask for load/store. + assert d in {16, 32, 64, 128}, "Only support d in {16, 32, 64, 128}" + + # In training, `window_size_left + window_size_global` is never larger than `max_seqlen_k`. + # in training, causal is always true. + # causal=True + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + ctx.softmax_scale = softmax_scale + + window_size_left = window_size[1] if window_size[1] >= 0 and window_size[1] <= seqlen_k else seqlen_k + window_size_global = window_size[0] if window_size[0] > 0 and window_size[0] < seqlen_k else 0 + SEQOFFSET = seqlen_k - seqlen_q - window_size_left + + ctx.window_size_global, ctx.window_size_left, ctx.SEQOFFSET = window_size_global, window_size_left, SEQOFFSET + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, num_heads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + # using 3d grid to avoid div & rem + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), num_heads, batch) + _fwd_kernel[grid]( + q, + k, + v, + o, + lse, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + o.stride(0), + o.stride(2), + o.stride(1), + seqlen_q, + seqlen_k, + seqlen_q_rounded, + num_heads, + num_heads_k, + window_size_global, # window_size_global + SEQOFFSET, # window_size_left + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + d, # d_rounded (BLOCK_HEADDIM) actually + # BLOCK_M=128, + # BLOCK_N=64, + # num_warps=num_warps, + # num_stages=1, + ) + + ctx.save_for_backward(q, k, v, o, lse) + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
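+        # Backward pass outline (a summary of the launches below):
+        #   1. _bwd_preprocess_do_o_dot computes delta = rowsum(o * do), which both
+        #      gradient kernels reuse.
+        #   2. _bwd_dk_dv_kernel loops over key/value blocks and writes dK/dV into
+        #      per-query-head buffers (dk_expanded / dv_expanded).
+        #   3. _bwd_dq_kernel loops over query blocks and accumulates dQ.
+        # For GQA/MQA, dk_expanded/dv_expanded are then summed over the
+        # kv_group_size = num_heads // num_heads_k query heads sharing each KV head.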
+ with torch.inference_mode(): + if do.stride(-1) != 1: do = do.contiguous() + batch, seqlen_q, num_heads, d = q.shape + _, seqlen_k, num_heads_k, _ = k.shape + kv_group_size = num_heads // num_heads_k + delta = torch.empty_like(lse) + max_seqlen_q_rounded = lse.shape[-1] + + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), num_heads, batch) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + seqlen_q, + max_seqlen_q_rounded, + d, + ) + dk_expanded = torch.empty((batch, seqlen_k, num_heads, d), dtype=do.dtype, device=do.device) + dv_expanded = torch.empty((batch, seqlen_k, num_heads, d), dtype=do.dtype, device=do.device) + grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]), num_heads, batch) + _bwd_dk_dv_kernel[grid]( + q, + k, + v, + do, + dk_expanded, + dv_expanded, + lse, + delta, + ctx.softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + dk_expanded.stride(0), + dk_expanded.stride(2), + dk_expanded.stride(1), + dv_expanded.stride(0), + dv_expanded.stride(2), + dv_expanded.stride(1), + num_heads, + num_heads_k, + seqlen_q, + seqlen_k, + max_seqlen_q_rounded, + ctx.window_size_global, + ctx.SEQOFFSET, + d, #BLOCK_HEADDIM=d, + # BLOCK_M=128, BLOCK_N=128, + # num_warps=8, + # num_stages=1, + ) + dk = dk_expanded.reshape(batch, seqlen_k, num_heads_k, kv_group_size, d).sum(dim=3, keepdim=False) + dv = dv_expanded.reshape(batch, seqlen_k, num_heads_k, kv_group_size, d).sum(dim=3, keepdim=False) + dq = torch.zeros_like(q) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), num_heads, batch) + _bwd_dq_kernel[grid]( + q, + k, + v, + do, + dq, + lse, + delta, + ctx.softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + dq.stride(0), + dq.stride(2), + dq.stride(1), + num_heads, + num_heads_k, + seqlen_q, + seqlen_k, + max_seqlen_q_rounded, + ctx.window_size_global, + ctx.SEQOFFSET, + d, #BLOCK_HEADDIM=d, + # BLOCK_M=128, BLOCK_N=128, + # num_warps=8, + # num_stages=1, + ) + + # This is how many gradients you have to return as many arguments forward + return dq, dk, dv, None, None, None, None, None, None, None, None + +flash_attn_func = FlashAttnFunc.apply diff --git a/collie/models/moss2alpha/sparse_varlen_kernel.py b/collie/models/moss2alpha/sparse_varlen_kernel.py new file mode 100644 index 0000000..055df83 --- /dev/null +++ b/collie/models/moss2alpha/sparse_varlen_kernel.py @@ -0,0 +1,996 @@ +import math + +import torch +import triton +import triton.language as tl + +@triton.jit +def _fwd_kernel_one_row_block( + start_m, + Q: tl.const, + K: tl.const, + V: tl.const, + Out, + Lse, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_kb: tl.constexpr, + stride_vb: tl.constexpr, + stride_ob: tl.constexpr, + actual_seqlen_q, + actual_seqlen_k, + window_size_global, + SEQOFFSET, + EVEN_M, + EVEN_N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + offs_m = start_m * BLOCK_M + tl.arange(0, 
BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + q_ptrs = (Q + (offs_m[:, None] * stride_qb + offs_d[None, :])) + l_ptrs = Lse + offs_m + out_ptrs = (Out + (offs_m[:, None] * stride_ob + offs_d[None, :])) + if EVEN_M: + q = tl.load(q_ptrs, cache_modifier=".cg") + else: + q = tl.load(q_ptrs, mask=offs_m[:, None] < actual_seqlen_q, other=0.0, cache_modifier=".cg") + + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = softmax_scale * log2e + # load k, v + offs_n_base = tl.arange(0, BLOCK_N) + k_ptrs = (K + (offs_d[:, None] + offs_n_base[None, :] * stride_kb)) # (BLOCK_HEADDIM, BLOCK_N) + v_ptrs = (V + (offs_n_base[:, None] * stride_vb + offs_d[None, :])) + global_end_n = tl.cdiv(window_size_global, BLOCK_N) * BLOCK_N + global_end_n = tl.multiple_of(global_end_n, BLOCK_N) + # loop of global part(could be 0) + for start_n in range(0, global_end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_base + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kb, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vb, cache_modifier=".cg") + else: + mask_n = offs_n < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kb, mask=mask_n[None, :], other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vb, mask=mask_n[:, None], other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + # no need to mask for EVEN_N in `global` case + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m[:, None] >= offs_n[None, :]) & + (((SEQOFFSET + offs_m)[:, None] <= offs_n[None, :]) | ((offs_n < window_size_global)[None, :])) + , 0, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) + p = tl.math.exp2(qk * qk_scale - m_i_new[:, None] * qk_scale) + m_i = m_i_new + + # -- scale and update acc: acc *= alpha[:, None]-- + acc *= alpha[:, None] + acc = tl.dot(p.to(q.dtype), v, acc) + + # -- update m_i and l_i -- + l_i = tl.fma(l_i, alpha, tl.sum(p, 1)) + + local_start_n = tl.maximum(((start_m * BLOCK_M + SEQOFFSET) // BLOCK_N) * BLOCK_N, global_end_n) + local_start_n = tl.multiple_of(local_start_n, BLOCK_N) + end_n = tl.minimum((start_m + 1) * BLOCK_M, actual_seqlen_k) + for start_n in range(local_start_n, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_base + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kb, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vb, cache_modifier=".cg") + else: + mask_n = offs_n < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kb, mask=mask_n[None, :], other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vb, mask=mask_n[:, None], other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where(offs_n[None, :] < actual_seqlen_k, 0, float("-inf")) + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m[:, None] >= offs_n[None, :]) & ### $ + (((SEQOFFSET + offs_m)[:, None] <= offs_n[None, :])) # `local` part so we need not to (start_n + offs_n < window_size_global) + , 0, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 
1)) + alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) + p = tl.math.exp2(qk * qk_scale - m_i_new[:, None] * qk_scale) + # -- scale and update acc: acc *= alpha[:, None]-- + acc *= alpha[:, None] + acc = tl.dot(p.to(q.dtype), v, acc, input_precision="tf32") + + # -- update m_i and l_i -- + l_i = tl.fma(l_i, alpha, tl.sum(p, 1)) + m_i = m_i_new + + acc = acc * (1.0 / l_i[:, None]) # reduce the number of division + l = tl.fma(m_i, softmax_scale, tl.log(l_i)) # log(normalizer) + # initialize pointers to output + if EVEN_M: + tl.store(l_ptrs, l, cache_modifier=".cs") # .cs is for data accessed once + tl.store(out_ptrs, acc, cache_modifier=".cs") + else: + mask_m = offs_m < actual_seqlen_q + tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cs") + tl.store(out_ptrs, acc, mask=mask_m[:, None], cache_modifier=".cs") + +@triton.heuristics( + { + "BLOCK_M": lambda args: 128, + "BLOCK_N": lambda args: 128, # 64 or 128 + "num_warps": lambda args: 8, + "num_stages": lambda args: 3, + } +) +@triton.jit +def _fwd_kernel( + Q: tl.const, + K: tl.const, + V: tl.const, + Out, + Lse, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_qh: tl.constexpr, + stride_kb: tl.constexpr, + stride_kh: tl.constexpr, + stride_vb: tl.constexpr, + stride_vh: tl.constexpr, + stride_ob: tl.constexpr, + stride_oh: tl.constexpr, + cu_seqlen_q: tl.const, + cu_seqlen_k: tl.const, + max_seqlen_q_rounded, + nheads: tl.constexpr, + nheads_k: tl.constexpr, + window_size_global: tl.constexpr, + window_size_left: tl.constexpr, + d_rounded: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_b = tl.program_id(2) + + seqlen_q_start = tl.load(( cu_seqlen_q + off_b ), cache_modifier=".cg") # scalar + seqlen_q_end = tl.load(( cu_seqlen_q + off_b + 1 ), cache_modifier=".cg") # scalar + actual_seqlen_q = seqlen_q_end - seqlen_q_start + # invalid grid block for this seq + if actual_seqlen_q <= start_m * BLOCK_M: + return + seqlen_k_start = tl.load(( cu_seqlen_k + off_b ), cache_modifier=".cg") # scalar + seqlen_k_end = tl.load(( cu_seqlen_k + off_b + 1 ), cache_modifier=".cg") # scalar + actual_seqlen_k = seqlen_k_end - seqlen_k_start + + window_size_left = window_size_left if window_size_left >= 0 and window_size_left <= actual_seqlen_k else actual_seqlen_k + window_size_global = window_size_global if window_size_global > 0 and window_size_global < actual_seqlen_k else 0 + SEQOFFSET = actual_seqlen_k - actual_seqlen_q - window_size_left + + off_h = tl.program_id(1) + Q += ( seqlen_q_start * stride_qb + off_h * stride_qh ) + Out += ( seqlen_q_start * stride_ob + off_h * stride_oh ) + Lse += (off_b * nheads + off_h) * max_seqlen_q_rounded + + off_h_kv = off_h * nheads_k // nheads + K += ( seqlen_k_start * stride_kb + off_h_kv * stride_kh ) + V += ( seqlen_k_start * stride_vb + off_h_kv * stride_vh ) + + EVEN_M = actual_seqlen_q % BLOCK_M == 0 + EVEN_N = actual_seqlen_k % BLOCK_N == 0 + + _fwd_kernel_one_row_block( + start_m, + Q, + K, + V, + Out, + Lse, + softmax_scale, + stride_qb, + stride_kb, + stride_vb, + stride_ob, + actual_seqlen_q, + actual_seqlen_k, + window_size_global, + SEQOFFSET, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_HEADDIM=d_rounded, + ) + +@triton.heuristics( + { + "BLOCK_M": lambda args: 128, + # "num_warps": lambda args: 8, + # "num_stages": lambda args: 1, + } +) +@triton.jit +def _bwd_preprocess_do_o_dot( + Out: tl.const, + DO: tl.const, + Delta, + stride_ob: tl.constexpr, + stride_oh: tl.constexpr, + stride_dob: 
tl.constexpr, + stride_doh: tl.constexpr, + cu_seqlen_q: tl.const, + max_seqlen_q_rounded, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + seqlen_q_start = tl.load(( cu_seqlen_q + off_b ), cache_modifier=".cg") # scalar + seqlen_q_end = tl.load(( cu_seqlen_q + off_b + 1 ), cache_modifier=".cg") # scalar + actual_seqlen_q = seqlen_q_end - seqlen_q_start + # invalid grid block for this seq + if actual_seqlen_q <= start_m * BLOCK_M: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o_ptrs = Out + seqlen_q_start * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_ob + offs_d[None, :] + do_ptrs = ( + DO + seqlen_q_start * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dob + offs_d[None, :] + ) + EVEN_M = actual_seqlen_q % BLOCK_M == 0 + + if EVEN_M: + o = tl.load(o_ptrs, cache_modifier=".cg").to(tl.float32) + do = tl.load(do_ptrs, cache_modifier=".cg").to(tl.float32) + else: + mask_m = (offs_m < actual_seqlen_q)[:, None] + o = tl.load(o_ptrs, mask=mask_m, other=0.0, cache_modifier=".cg").to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m, other=0.0, cache_modifier=".cg").to(tl.float32) + + delta = tl.sum(o * do, axis=1) + # off_b * tl.num_programs(1) + off_h == off_b * nheads + off_h + delta_ptrs = Delta + (off_b * tl.num_programs(1) + off_h) * max_seqlen_q_rounded + offs_m + # write-back + if EVEN_M: + tl.store(delta_ptrs, delta, cache_modifier=".cs") + else: + tl.store(delta_ptrs, delta, mask=(offs_m < actual_seqlen_q), cache_modifier=".cs") + +@triton.heuristics( + { + "BLOCK_M": lambda args: 64, + "BLOCK_N": lambda args: 64, # out of resource: shared memory, Required: 198656, Hardware limit: 166912. 
Reducing block sizes + "num_warps": lambda args: 4, + "num_stages": lambda args: 2, + } +) +@triton.jit +def _bwd_dk_dv_kernel( + Q: tl.const, + K: tl.const, + V: tl.const, + DO: tl.const, + DK, + DV, + LSE: tl.const, + D: tl.const, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_qh: tl.constexpr, + stride_kb: tl.constexpr, + stride_kh: tl.constexpr, + stride_vb: tl.constexpr, + stride_vh: tl.constexpr, + stride_dob: tl.constexpr, + stride_doh: tl.constexpr, + stride_dkb: tl.constexpr, + stride_dkh: tl.constexpr, + stride_dvb: tl.constexpr, + stride_dvh: tl.constexpr, + nheads: tl.constexpr, + nheads_k: tl.constexpr, + cu_seqlen_q: tl.const, + cu_seqlen_k: tl.const, + max_seqlen_q_rounded, + window_size_global: tl.constexpr, + window_size_left: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + seqlen_k_start = tl.load(( cu_seqlen_k + off_b ), cache_modifier='.cg') # scalar + seqlen_k_end = tl.load(( cu_seqlen_k + off_b + 1 ), cache_modifier='.cg') # scalar + actual_seqlen_k = seqlen_k_end - seqlen_k_start + # invalid grid block for this seq + start_n = tl.program_id(0) + if actual_seqlen_k <= start_n * BLOCK_N: + return + seqlen_q_start = tl.load(( cu_seqlen_q + off_b ), cache_modifier='.cg') # scalar + seqlen_q_end = tl.load(( cu_seqlen_q + off_b + 1 ), cache_modifier='.cg') # scalar + actual_seqlen_q = seqlen_q_end - seqlen_q_start + + window_size_left = window_size_left if window_size_left >= 0 and window_size_left <= actual_seqlen_k else actual_seqlen_k + window_size_global = window_size_global if window_size_global > 0 and window_size_global < actual_seqlen_k else 0 + SEQOFFSET = actual_seqlen_k - actual_seqlen_q - window_size_left + + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = softmax_scale * log2e + + # offset pointers for batch/head + Q += seqlen_q_start * stride_qb + off_h * stride_qh + off_h_kv = off_h * nheads_k // nheads + K += seqlen_k_start * stride_kb + off_h_kv * stride_kh + V += seqlen_k_start * stride_vb + off_h_kv * stride_vh + DO += seqlen_q_start * stride_dob + off_h * stride_doh + DK += seqlen_k_start * stride_dkb + off_h * stride_dkh + DV += seqlen_k_start * stride_dvb + off_h * stride_dvh + # pointer to row-wise quantities in value-like data + off_hb = off_b * nheads + off_h + D += off_hb * max_seqlen_q_rounded + LSE += off_hb * max_seqlen_q_rounded + + begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n_slice = offs_n[None, :] + offs_m_base = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # k & v transposed here + q_ptrs = Q + (offs_qm[:, None] * stride_qb + offs_d[None, :]) + k_ptrs = K + (offs_n_slice * stride_kb + offs_d[:, None]) # transposed here + v_ptrs = V + (offs_n_slice * stride_vb + offs_d[:, None]) # transposed here + do_ptrs = DO + (offs_qm[:, None] * stride_dob + offs_d[None, :]) + dv_ptrs = DV + (offs_n[:, None] * stride_dvb + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkb + offs_d[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + + EVEN_M = actual_seqlen_q % BLOCK_M == 0 + EVEN_N = actual_seqlen_k % BLOCK_N == 0 + + # load k , v + if EVEN_N: + k = tl.load(k_ptrs, cache_modifier=".cg") + v = tl.load(v_ptrs, cache_modifier=".cg") + else: + 
mask_n = offs_n_slice < actual_seqlen_k + k = tl.load(k_ptrs, mask=mask_n, other=0.0, cache_modifier=".cg") + v = tl.load(v_ptrs, mask=mask_n, other=0.0, cache_modifier=".cg") + + # loop over rows + num_block_m = tl.cdiv(actual_seqlen_q, BLOCK_M) + end_m = num_block_m * BLOCK_M # $ + global_end_n = tl.cdiv(window_size_global, BLOCK_N) - 1 + local_start_n = tl.cdiv(actual_seqlen_k-window_size_left, BLOCK_N) - 1 + if (start_n > global_end_n) & (start_n < local_start_n): + end_m = tl.cdiv((start_n + 1) * BLOCK_N - SEQOFFSET, BLOCK_M) * BLOCK_M # $ + for start_m in range(begin_m, end_m, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m = start_m + offs_m_base + # load q, l, do on-chip + if EVEN_M: + q = tl.load(q_ptrs, cache_modifier=".cg") + do = tl.load(do_ptrs, cache_modifier=".cg") + l = tl.load(LSE + offs_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, cache_modifier=".cg") + else: + mask_m = offs_m < actual_seqlen_q + q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0, cache_modifier=".cg") + do = tl.load(do_ptrs, mask=mask_m[:, None], other=0.0, cache_modifier=".cg") + l = tl.load(LSE + offs_m, mask=mask_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, mask=mask_m, cache_modifier=".cg") + q_ptrs += BLOCK_M * stride_qb + do_ptrs += BLOCK_M * stride_dob + + # recompute p = softmax(qk, dim=-1).T + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where(offs_n_slice < actual_seqlen_k, 0, float("-inf")) + qk += tl.where( # True will not mask + (offs_m[:, None] >= offs_n_slice) & + (((SEQOFFSET + offs_m)[:, None] <= offs_n_slice) | (offs_n_slice < window_size_global)) + , 0, float("-inf")) + + p = tl.math.exp2(qk * qk_scale - l[:, None] * log2e) + dv = tl.dot(tl.trans(p.to(do.dtype)), do, dv, input_precision="tf32") + # compute dp = dot(v, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do, v, dp, input_precision="tf32") + + ds = (p * (dp - Di[:, None])).to(q.dtype) + # compute dk = dot(ds.T, q) + dk = tl.dot(tl.trans(ds), q, dk, input_precision="tf32") + + dk *= softmax_scale + # write-back + if EVEN_N: + tl.store(dv_ptrs, dv.to(k.dtype), cache_modifier=".cs") + tl.store(dk_ptrs, dk.to(k.dtype), cache_modifier=".cs") + else: + mask_n = offs_n < actual_seqlen_k + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None], cache_modifier=".cs") + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None], cache_modifier=".cs") + + +@triton.heuristics( + { + "BLOCK_M": lambda args: 64, + "BLOCK_N": lambda args: 64, # 64 or 128 + "num_warps": lambda args: 4, # must be 4, or will have a race condition + "num_stages": lambda args: 2, + } +) +@triton.jit +def _bwd_dq_kernel( + Q: tl.const, + K: tl.const, + V: tl.const, + DO: tl.const, + DQ, + LSE: tl.const, + D: tl.const, + softmax_scale: tl.constexpr, + stride_qb: tl.constexpr, + stride_qh: tl.constexpr, + stride_kb: tl.constexpr, + stride_kh: tl.constexpr, + stride_vb: tl.constexpr, + stride_vh: tl.constexpr, + stride_dob: tl.constexpr, + stride_doh: tl.constexpr, + stride_dqb: tl.constexpr, + stride_dqh: tl.constexpr, + nheads: tl.constexpr, + nheads_k: tl.constexpr, + cu_seqlen_q: tl.const, + cu_seqlen_k: tl.const, + max_seqlen_q_rounded, + window_size_global: tl.constexpr, + window_size_left: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_h = 
tl.program_id(1) + off_b = tl.program_id(2) + + seqlen_q_start = tl.load(( cu_seqlen_q + off_b ), cache_modifier='.cg') # scalar + seqlen_q_end = tl.load(( cu_seqlen_q + off_b + 1 ), cache_modifier='.cg') # scalar + actual_seqlen_q = seqlen_q_end - seqlen_q_start + start_m = tl.program_id(0) + # invalid grid block for this seq + if actual_seqlen_q <= start_m * BLOCK_N: + return + + seqlen_k_start = tl.load(( cu_seqlen_k + off_b ), cache_modifier='.cg') # scalar + seqlen_k_end = tl.load(( cu_seqlen_k + off_b + 1 ), cache_modifier='.cg') # scalar + actual_seqlen_k = seqlen_k_end - seqlen_k_start + + window_size_left = window_size_left if window_size_left >= 0 and window_size_left <= actual_seqlen_k else actual_seqlen_k + window_size_global = window_size_global if window_size_global > 0 and window_size_global < actual_seqlen_k else 0 + SEQOFFSET = actual_seqlen_k - actual_seqlen_q - window_size_left + + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = softmax_scale * log2e + + # offset pointers for batch/head + Q += seqlen_q_start * stride_qb + off_h * stride_qh + off_h_kv = off_h * nheads_k // nheads + K += seqlen_k_start * stride_kb + off_h_kv * stride_kh + V += seqlen_k_start * stride_vb + off_h_kv * stride_vh + DO += seqlen_q_start * stride_dob + off_h * stride_doh + DQ += seqlen_q_start * stride_dqb + off_h * stride_dqh + # pointer to row-wise quantities in value-like data + off_hb = off_b * nheads + off_h + D += off_hb * max_seqlen_q_rounded + LSE += off_hb * max_seqlen_q_rounded + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_m_slice = offs_m[:, None] + q_ptrs = (Q + (offs_m_slice * stride_qb + offs_d[None, :])) + do_ptrs = (DO + (offs_m_slice * stride_dob + offs_d[None, :])) + dq_ptrs = DQ + (offs_m_slice * stride_dqb + offs_d[None, :]) + + EVEN_M = actual_seqlen_q % BLOCK_M == 0 + EVEN_N = actual_seqlen_k % BLOCK_N == 0 + + # load q & do: it will stay in SRAM throughout + # load q & do & load l & delta: it will stay in SRAM throughout + if EVEN_M: + q = tl.load(q_ptrs, cache_modifier=".cg") + do = tl.load(do_ptrs, cache_modifier=".cg") + l = tl.load(LSE + offs_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, cache_modifier=".cg") + else: + mask_m = offs_m < actual_seqlen_q + l = tl.load(LSE + offs_m, mask=mask_m, cache_modifier=".cg") + Di = tl.load(D + offs_m, mask=mask_m, cache_modifier=".cg") + q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg") + do = tl.load(do_ptrs, mask=mask_m[:, None], cache_modifier=".cg") + l_slice_scale = l[:, None] * log2e + + # dq init + dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # load k, v. 
k, v transposed here + offs_n_base_slice = (tl.arange(0, BLOCK_N))[None, :] + k_ptrs = (K + (offs_d[:, None] + offs_n_base_slice * stride_kb)) # (BLOCK_HEADDIM, BLOCK_N) + v_ptrs = (V + (offs_d[:, None] + offs_n_base_slice * stride_vb)) + global_end_n = tl.cdiv(window_size_global, BLOCK_N) * BLOCK_N + global_end_n = tl.multiple_of(global_end_n, BLOCK_N) + # loop of global part(could be 0) + for start_n in range(0, global_end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n_slice = start_n + offs_n_base_slice + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kb, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vb, cache_modifier=".cg") + else: + mask_n = offs_n_slice < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kb, mask=mask_n, other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vb, mask=mask_n, other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + # no need to mask for EVEN_N in `global` case + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m_slice >= offs_n_slice) & + ((SEQOFFSET + offs_m_slice <= offs_n_slice) | (offs_n_slice < window_size_global)) + , 0, float("-inf")) + + # -- compute p --- + p = tl.math.exp2(qk * qk_scale - l_slice_scale) + # compute dq = dot(p, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do.to(q.dtype), v, dp, input_precision="tf32") + + ds = (p * (dp - Di[:, None])).to(q.dtype) + dq = tl.dot(ds, tl.trans(k), dq) + + local_start_n = tl.maximum(((start_m * BLOCK_M + SEQOFFSET) // BLOCK_N) * BLOCK_N, global_end_n) + local_start_n = tl.multiple_of(local_start_n, BLOCK_N) + end_n = tl.minimum((start_m + 1) * BLOCK_M, actual_seqlen_k) + for start_n in range(local_start_n, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n_slice = start_n + offs_n_base_slice + + # load k, v + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kb, cache_modifier=".cg") + v = tl.load(v_ptrs + start_n * stride_vb, cache_modifier=".cg") + else: + mask_n = offs_n_slice < actual_seqlen_k + k = tl.load(k_ptrs + start_n * stride_kb, mask=mask_n, other=0.0, cache_modifier=".cg") # (BLOCK_HEADDIM, BLOCK_N) + v = tl.load(v_ptrs + start_n * stride_vb, mask=mask_n, other=0.0, cache_modifier=".cg") + + # -- compute qk ---- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, qk, input_precision="tf32") + + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where(offs_n_slice < actual_seqlen_k, 0, float("-inf")) + # if IS_GLOBAL: + qk += tl.where( # True will not mask + (offs_m_slice >= offs_n_slice) & ### $ + ((SEQOFFSET + offs_m_slice <= offs_n_slice)) # `local` part so we need not to (start_n + offs_n < window_size_global) + , 0, float("-inf")) + + # -- compute p --- + p = tl.math.exp2(qk * qk_scale - l_slice_scale) + # compute dq = dot(p, do) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do.to(q.dtype), v, dp, input_precision="tf32") + + ds = (p * (dp - Di[:, None])).to(q.dtype) + dq = tl.dot(ds, tl.trans(k), dq) + + dq *= softmax_scale + if EVEN_M: + tl.store(dq_ptrs, dq.to(q.dtype), cache_modifier=".cs") + else: + tl.store(dq_ptrs, dq.to(q.dtype), mask=offs_m_slice < actual_seqlen_q, cache_modifier=".cs") + +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, cu_seqlen_q, cu_seqlen_k, max_seqlen_q, 
+class FlashAttnVarlenFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, cu_seqlen_q, cu_seqlen_k, max_seqlen_q, max_seqlen_k,
+                bias=None, causal=False, softmax_scale=None, window_size=(-1,-1,-1),
+                ):
+        """
+        q: ((b s), nheads, headdim)
+        k, v: ((b s), nheads_k, headdim)
+        cu_seqlen_q, cu_seqlen_k: (batch+1,) torch.Tensor, cumulative sequence lengths
+        bias: unused; kept only for API compatibility.
+        """
+        # The last dimension must be contiguous
+        _, num_heads, d = q.shape
+        _, num_heads_k, dk = k.shape
+        assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+        assert num_heads % num_heads_k == 0, "num_heads must be divisible by num_heads_k"
+        assert d == dk and dk == v.shape[-1] and num_heads_k == v.shape[-2], "head dimensions and kv head counts must match"
+        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
+        assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
+        assert q.is_cuda and k.is_cuda and v.is_cuda, "All tensors must be on GPU"
+        # I choose not to support the case where `d` is not a power of 2,
+        # which is sufficient for current LLMs and simplifies the masks for load/store.
+        assert d in {16, 32, 64, 128}, "Only support d in {16, 32, 64, 128}"
+
+        # In training, `window_size_left + window_size_global` is never larger than `max_seqlen_k`.
+        # In training, causal is always True.
+        # causal=True
+        softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+        ctx.softmax_scale = softmax_scale
+        ctx.window_size_global, ctx.window_size_left, _ = window_size
+        ctx.cu_seqlen_q = cu_seqlen_q
+        ctx.cu_seqlen_k = cu_seqlen_k
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+
+        seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128
+        batch = cu_seqlen_k.shape[0] - 1
+        lse = torch.empty((batch, num_heads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+        o = torch.empty_like(q)
+        # using 3d grid to avoid div & rem
+        grid = lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), num_heads, batch)
+        _fwd_kernel[grid](
+            q,
+            k,
+            v,
+            o,
+            lse,
+            softmax_scale,
+            q.stride(0),
+            q.stride(1),
+            k.stride(0),
+            k.stride(1),
+            v.stride(0),
+            v.stride(1),
+            o.stride(0),
+            o.stride(1),
+            cu_seqlen_q,
+            cu_seqlen_k,
+            seqlen_q_rounded,
+            num_heads,
+            num_heads_k,
+            window_size[0], # window_size_global
+            window_size[1], # window_size_left
+            # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
+            d, # BLOCK_HEADDIM (d is already a power of 2, so no rounding is needed)
+            # BLOCK_M=128,
+            # BLOCK_N=64,
+            # num_warps=num_warps,
+            # num_stages=1,
+        )
+
+        ctx.save_for_backward(q, k, v, o, lse)
+        return o
+
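+    # Backward pass outline (three Triton launches):
+    #   1) _bwd_preprocess_do_o_dot computes delta (the row-wise dot of `do` and `o`) used by both
+    #      backward kernels;
+    #   2) _bwd_dk_dv_kernel accumulates dk/dv with one slot per *query* head (dk_expanded/dv_expanded);
+    #   3) _bwd_dq_kernel accumulates dq per query block.
+    # For GQA/MQA, dk_expanded/dv_expanded are then reduced onto the key/value heads by the
+    # reshape(...).sum(dim=2) below.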
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, lse = ctx.saved_tensors
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            if do.stride(-1) != 1: do = do.contiguous()
+            bs_q, num_heads, d = q.shape
+            bs_k, num_heads_k, _ = k.shape
+            kv_group_size = num_heads // num_heads_k
+            batch = ctx.cu_seqlen_q.shape[0] - 1
+            delta = torch.empty_like(lse)
+            max_seqlen_q_rounded = lse.shape[-1]
+
+            grid = lambda META: (triton.cdiv(ctx.max_seqlen_q, META["BLOCK_M"]), num_heads, batch)
+            _bwd_preprocess_do_o_dot[grid](
+                o,
+                do,
+                delta,
+                o.stride(0),
+                o.stride(1),
+                do.stride(0),
+                do.stride(1),
+                ctx.cu_seqlen_q, # Tensor
+                max_seqlen_q_rounded,
+                BLOCK_HEADDIM=d,
+            )
+            dk_expanded = torch.empty((bs_k, num_heads, d), dtype=do.dtype, device=do.device)
+            dv_expanded = torch.empty((bs_k, num_heads, d), dtype=do.dtype, device=do.device)
+            grid = lambda META: (triton.cdiv(ctx.max_seqlen_k, META["BLOCK_N"]), num_heads, batch)
+            _bwd_dk_dv_kernel[grid](
+                q,
+                k,
+                v,
+                do,
+                dk_expanded,
+                dv_expanded,
+                lse,
+                delta,
+                ctx.softmax_scale,
+                q.stride(0),
+                q.stride(1),
+                k.stride(0),
+                k.stride(1),
+                v.stride(0),
+                v.stride(1),
+                do.stride(0),
+                do.stride(1),
+                dk_expanded.stride(0),
+                dk_expanded.stride(1),
+                dv_expanded.stride(0),
+                dv_expanded.stride(1),
+                num_heads,
+                num_heads_k,
+                ctx.cu_seqlen_q,
+                ctx.cu_seqlen_k,
+                max_seqlen_q_rounded,
+                ctx.window_size_global,
+                ctx.window_size_left,
+                d, #BLOCK_HEADDIM=d,
+                # BLOCK_M=128, BLOCK_N=128,
+                # num_warps=8,
+                # num_stages=1,
+            )
+            dk = dk_expanded.reshape(bs_k, num_heads_k, kv_group_size, d).sum(dim=2, keepdim=False)
+            dv = dv_expanded.reshape(bs_k, num_heads_k, kv_group_size, d).sum(dim=2, keepdim=False)
+            dq = torch.zeros_like(q)
+            grid = lambda META: (triton.cdiv(ctx.max_seqlen_q, META["BLOCK_M"]), num_heads, batch)
+            _bwd_dq_kernel[grid](
+                q,
+                k,
+                v,
+                do,
+                dq,
+                lse,
+                delta,
+                ctx.softmax_scale,
+                q.stride(0),
+                q.stride(1),
+                k.stride(0),
+                k.stride(1),
+                v.stride(0),
+                v.stride(1),
+                do.stride(0),
+                do.stride(1),
+                dq.stride(0),
+                dq.stride(1),
+                num_heads,
+                num_heads_k,
+                ctx.cu_seqlen_q,
+                ctx.cu_seqlen_k,
+                max_seqlen_q_rounded,
+                ctx.window_size_global,
+                ctx.window_size_left,
+                d, #BLOCK_HEADDIM=d,
+                # BLOCK_M=128, BLOCK_N=128,
+                # num_warps=8,
+                # num_stages=1,
+            )
+
+        # Return one gradient per argument of forward (None for the non-tensor / non-differentiable ones)
+        return dq, dk, dv, None, None, None, None, None, None, None, None
+
+flash_attn_varlen_func = FlashAttnVarlenFunc.apply
+
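+# Illustrative usage sketch only (not executed anywhere in this module). The shapes, head counts
+# and window sizes below are examples, not requirements: two sequences of lengths 3 and 5 are
+# packed into 8 tokens, with 2 global key tokens plus a sliding window of 4 keys
+# (window_size_left = window - 1 = 3); the third window_size entry is unused.
+# Note that `Function.apply` only accepts positional arguments.
+def _example_flash_attn_varlen_usage():  # pragma: no cover - documentation only
+    device, dtype = "cuda", torch.float16
+    nheads, nheads_k, headdim = 8, 2, 64
+    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)
+    total_tokens, max_seqlen = 8, 5
+    q = torch.randn(total_tokens, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(total_tokens, nheads_k, headdim, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(total_tokens, nheads_k, headdim, device=device, dtype=dtype, requires_grad=True)
+    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
+                                 None, True, None, (2, 3, 0))
+    out.sum().backward()
+    return q.grad, k.grad, v.grad
+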
+class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, kv, cu_seqlen_q, cu_seqlen_k, max_seqlen_q, max_seqlen_k,
+                bias=None, causal=False, softmax_scale=None, window_size=(-1,-1,-1),
+                ):
+        """
+        q: ((b s), nheads, headdim)
+        kv: ((b s), 2, nheads_k, headdim), k and v packed along dim 1
+        cu_seqlen_q, cu_seqlen_k: (batch+1,) torch.Tensor, cumulative sequence lengths
+        bias: unused; kept only for API compatibility.
+        """
+        # Make sure that the last dimension is contiguous
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        q, k, v = maybe_contiguous(q), maybe_contiguous(kv[:, 0]), maybe_contiguous(kv[:, 1])
+        _, num_heads, d = q.shape
+        _, num_heads_k, dk = k.shape
+        assert num_heads % num_heads_k == 0, "num_heads must be divisible by num_heads_k"
+        assert d == dk and dk == v.shape[-1] and num_heads_k == v.shape[-2], "head dimensions and kv head counts must match"
+        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
+        assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
+        assert q.is_cuda and k.is_cuda and v.is_cuda, "All tensors must be on GPU"
+        # I choose not to support the case where `d` is not a power of 2,
+        # which is sufficient for current LLMs and simplifies the masks for load/store.
+        assert d in {16, 32, 64, 128}, "Only support d in {16, 32, 64, 128}"
+
+        # In training, `window_size_left + window_size_global` is never larger than `max_seqlen_k`.
+        # In training, causal is always True.
+        # causal=True
+        softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
+        ctx.softmax_scale = softmax_scale
+        ctx.window_size_global, ctx.window_size_left, _ = window_size
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+
+        seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128
+        batch = cu_seqlen_k.shape[0] - 1
+        lse = torch.empty((batch, num_heads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
+        o = torch.empty_like(q)
+        # using 3d grid to avoid div & rem
+        grid = lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), num_heads, batch)
+        _fwd_kernel[grid](
+            q,
+            k,
+            v,
+            o,
+            lse,
+            softmax_scale,
+            q.stride(0),
+            q.stride(1),
+            k.stride(0),
+            k.stride(1),
+            v.stride(0),
+            v.stride(1),
+            o.stride(0),
+            o.stride(1),
+            cu_seqlen_q,
+            cu_seqlen_k,
+            seqlen_q_rounded,
+            num_heads,
+            num_heads_k,
+            window_size[0], # window_size_global
+            window_size[1], # window_size_left
+            # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
+            d, # BLOCK_HEADDIM (d is already a power of 2, so no rounding is needed)
+            # BLOCK_M=128,
+            # BLOCK_N=64,
+            # num_warps=num_warps,
+            # num_stages=1,
+        )
+        ctx.save_for_backward(q, k, v, o, lse, cu_seqlen_q, cu_seqlen_k)
+        return o
+
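+    # The backward pass mirrors FlashAttnVarlenFunc.backward; the only difference is that the
+    # group-reduced dk/dv are written into a single packed dkv tensor with the same layout as the
+    # input kv, so a single gradient is returned for the packed argument.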
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, lse, cu_seqlen_q, cu_seqlen_k = ctx.saved_tensors
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+            q, k, v, o, do = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v), maybe_contiguous(o), maybe_contiguous(do)
+            bs_q, num_heads, d = q.shape
+            bs_k, num_heads_k, _ = k.shape
+            kv_group_size = num_heads // num_heads_k
+            batch = cu_seqlen_q.shape[0] - 1
+            delta = torch.empty_like(lse)
+            max_seqlen_q_rounded = lse.shape[-1]
+
+            grid = lambda META: (triton.cdiv(ctx.max_seqlen_q, META["BLOCK_M"]), num_heads, batch)
+            _bwd_preprocess_do_o_dot[grid](
+                o,
+                do,
+                delta,
+                o.stride(0),
+                o.stride(1),
+                do.stride(0),
+                do.stride(1),
+                cu_seqlen_q, # Tensor
+                max_seqlen_q_rounded,
+                BLOCK_HEADDIM=d,
+            )
+            dk_expanded = torch.empty((bs_k, num_heads, d), dtype=do.dtype, device=do.device)
+            dv_expanded = torch.empty((bs_k, num_heads, d), dtype=do.dtype, device=do.device)
+            # BLOCK_M = 128
+            # BLOCK_N = 64
+            # num_warps = 4
+            grid = lambda META: (triton.cdiv(ctx.max_seqlen_k, META["BLOCK_N"]), num_heads, batch)
+            _bwd_dk_dv_kernel[grid](
+                q,
+                k,
+                v,
+                do,
+                dk_expanded,
+                dv_expanded,
+                lse,
+                delta,
+                ctx.softmax_scale,
+                q.stride(0),
+                q.stride(1),
+                k.stride(0),
+                k.stride(1),
+                v.stride(0),
+                v.stride(1),
+                do.stride(0),
+                do.stride(1),
+                dk_expanded.stride(0),
+                dk_expanded.stride(1),
+                dv_expanded.stride(0),
+                dv_expanded.stride(1),
+                num_heads,
+                num_heads_k,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_seqlen_q_rounded,
+                ctx.window_size_global,
+                ctx.window_size_left,
+                d, #BLOCK_HEADDIM=d,
+                # BLOCK_M=128, BLOCK_N=128,
+                # num_warps=8,
+                # num_stages=1,
+            )
+            kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
+            dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
+            dkv[:, 0] = dk_expanded.reshape(bs_k, num_heads_k, kv_group_size, d).sum(dim=2, keepdim=False)
+            dkv[:, 1] = dv_expanded.reshape(bs_k, num_heads_k, kv_group_size, d).sum(dim=2, keepdim=False)
+            dq = torch.zeros_like(q)
+            grid = lambda META: (triton.cdiv(ctx.max_seqlen_q, META["BLOCK_M"]), num_heads, batch)
+            _bwd_dq_kernel[grid](
+                q,
+                k,
+                v,
+                do,
+                dq,
+                lse,
+                delta,
+                ctx.softmax_scale,
+                q.stride(0),
+                q.stride(1),
+                k.stride(0),
+                k.stride(1),
+                v.stride(0),
+                v.stride(1),
+                do.stride(0),
+                do.stride(1),
+                dq.stride(0),
+                dq.stride(1),
+                num_heads,
+                num_heads_k,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_seqlen_q_rounded,
+                ctx.window_size_global,
+                ctx.window_size_left,
+                d, #BLOCK_HEADDIM=d,
+                # BLOCK_M=128, BLOCK_N=128,
+                # num_warps=8,
+                # num_stages=1,
+            )
+        # Return one gradient per argument of forward (None for the non-tensor / non-differentiable ones)
+        return dq, dkv, None, None, None, None, None, None, None, None
+
+flash_attn_varlen_kvpacked_func = FlashAttnVarlenKVPackedFunc.apply
\ No newline at end of file