diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index 9e348b51..87041ab3 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -8,11 +8,11 @@
 
 try:
-    import awq_ft_ext
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
 
-    FT_INSTALLED = True
+    FA_INSTALLED = True
 except:
-    FT_INSTALLED = False
+    FA_INSTALLED = False
 
 HF_NEW_CACHE_FORMAT = False
 
 
@@ -28,6 +28,7 @@ class RoPE(nn.Module):
 
     def __init__(self, head_dim, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()
+        self.head_dim = head_dim
         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(head_dim, max_seq_len, rope_theta).to(device),
             requires_grad=False,
@@ -49,7 +50,23 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
         shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
         return freqs_cis.view(*shape)
 
-    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        start_pos: int,
+        seqlen: int,
+        partial: bool = False,
+    ):
+        if partial:
+            xq, xq_pass = (
+                xq[..., : self.head_dim],
+                xq[..., self.head_dim :],
+            )
+            xk, xk_pass = (
+                xk[..., : self.head_dim],
+                xk[..., self.head_dim :],
+            )
         xq_ = torch.view_as_complex(
             xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
         )
@@ -62,6 +79,10 @@ def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: in
         xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
         xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
 
+        if partial:
+            xq_out = torch.cat((xq_out, xq_pass), dim=-1)
+            xk_out = torch.cat((xk_out, xk_pass), dim=-1)
+
         return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
@@ -118,7 +139,7 @@ def __init__(
         rope_theta=10000,
         partial_rotary_factor=1.0,
         head_dim=None,
-        attn_logit_softcapping=None,
+        attn_logit_softcapping=0.0,
         **kwargs
     ):
         super().__init__()
@@ -147,18 +168,18 @@ def __init__(
         # attention shapes for self attention
         self.attention_shapes = get_attention_shapes(
             attention_shapes,
-            max_seq_len,
-            self.cache_batch_size,
             n_heads,
             n_kv_heads,
             self.head_dim,
         )
         # cache store that rolls cache
         self.cache = WindowedCache(
-            self.attention_shapes["cache_v"],
-            self.attention_shapes["cache_k"],
-            self.max_seq_len,
-            dev,
+            cache_batch_size=self.cache_batch_size,
+            n_heads=n_heads,
+            n_kv_heads=n_kv_heads,
+            head_dim=self.head_dim,
+            max_seq_len=self.max_seq_len,
+            device=dev,
         )
 
         if use_alibi:
@@ -174,13 +195,10 @@ def __init__(
 
         if kwargs.get("is_neox") is not None:
             self.is_neox = kwargs["is_neox"]
-        
+
         self.attn_logit_softcapping = attn_logit_softcapping
-        self.use_sdpa = kwargs.get("use_sdpa", False)
 
-    def forward(
-        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
-    ):
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape
 
         # Reallocate cache if batch size changes
@@ -196,21 +214,27 @@ def forward(
             self.start_pos = 0
 
         hf_is_generating = False
-        hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None
-        hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0
+        hf_is_first_forward = (
+            "past_key_value" in kwargs and kwargs["past_key_value"] is None
+        )
+        hf_is_new_cache_first_forward = (
+            "past_key_value" in kwargs
+            and isinstance(kwargs["past_key_value"], DynamicCache)
+            and kwargs["past_key_value"].get_seq_length() == 0
+        )
 
         if
self.is_hf_transformers and "use_cache" in kwargs: hf_is_generating = kwargs["use_cache"] - # print(kwargs["past_key_value"].get_seq_length()) - # In case we re-generate, we need to refresh the starting position # to 0. We detect it by checking if `past_key_values` is set to None, # which indicates that we are on the first step of `generate()`. # This is only applicable for `transformers` integration - if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating): + if ( + self.is_hf_transformers + and (hf_is_first_forward or hf_is_new_cache_first_forward) + ) or (self.is_hf_transformers and not hf_is_generating): self.start_pos = 0 - xqkv = self.qkv_proj(hidden_states) xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"]) @@ -219,114 +243,47 @@ def forward( xk = self.attention_shapes["xk_slice"](xqkv) xv = self.attention_shapes["xv_slice"](xqkv) - if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED: - xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"]) - xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"]) - xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"]) - - if not self.use_alibi: - # Partial rotary embedding - if self.partial_rotary_factor < 1: - xq_rot, xq_pass = ( - xq[..., : self.rotary_dim], - xq[..., self.rotary_dim :], - ) - xk_rot, xk_pass = ( - xk[..., : self.rotary_dim], - xk[..., self.rotary_dim :], - ) - xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen) - xq = torch.cat((xq_rot, xq_pass), dim=-1) - xk = torch.cat((xk_rot, xk_pass), dim=-1) - else: - xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen) - - values_store = xv.transpose(2, 1) - keys_store = ( - xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"]) - .permute(0, 2, 3, 1, 4) - .contiguous() + if not self.use_alibi: + xq, xk = self.rope.forward( + xq, xk, self.start_pos, seqlen, partial=self.partial_rotary_factor < 1 ) - self.cache.to(xq) - self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen) - - # Only necessary to retrieve from cache when we are not processing context - if seqlen == 1: - xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim) - - keys = xk - values = xv - - if self.n_kv_groups != 0: - keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups) - values = torch.repeat_interleave( - values, dim=2, repeats=self.n_kv_groups - ) - - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - - # Used in Gemma2 - if self.attn_logit_softcapping is not None: - scores = scores / self.attn_logit_softcapping - scores = torch.tanh(scores) - scores = scores * self.attn_logit_softcapping - - if self.use_sdpa: - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : keys.shape[-2]] - is_causal = True if causal_mask is None and seqlen > 1 else False - output = torch.nn.functional.scaled_dot_product_attention( - xq, - keys, - values, - attn_mask=causal_mask, - dropout_p=0.0, - is_causal=is_causal, - ) - else: - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - if self.use_alibi: - scores = self.alibi.forward(scores, seqlen) - - # When seqlen is 1, there is nothing else to attend to - if attention_mask is not None and seqlen > 1: - # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we - # need to slice it - if attention_mask.shape[-1] != 
seqlen: - attention_mask = attention_mask[:, :, :seqlen, :seqlen] - - scores = ( - scores + attention_mask - ) # (bs, n_local_heads, slen, cache_len + slen) - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) - - attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + self.cache.to(xq) + self.cache.update_kv( + values_store=xv, + keys_store=xk, + batch_size=bsz, + start_pos=self.start_pos, + seqlen=seqlen, + ) + + if seqlen > 1: + output = flash_attn_func( + q=xq, + k=xk, + v=xv, + causal=True, + alibi_slopes=self.alibi.slopes if self.alibi is not None else None, + softcap=self.attn_logit_softcapping, + ) else: - xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"]) - xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"]) - xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"]) - - alibi_slopes = self.alibi.slopes if self.alibi is not None else None - attention_weight = awq_ft_ext.single_query_attention( - xq, # query - xk, # key - xv, # value - self.cache.k, # key cache - self.cache.v, # value cache - None, # length per sample - alibi_slopes, # alibi slopes - self.start_pos, # timestep - self.rotary_dim, # rotary embedding dimension - self.rope_theta, # rotary embedding base - self.is_neox, # is neox + cache_seqlens = torch.full( + (bsz,), self.start_pos + seqlen, dtype=torch.int32, device=xq.device + ) + + output = flash_attn_with_kvcache( + q=xq, + k=xk, + k_cache=self.cache.k, + v=xv, + v_cache=self.cache.v, + cache_seqlens=cache_seqlens, + causal=True, + alibi_slopes=self.alibi.slopes if self.alibi is not None else None, + softcap=self.attn_logit_softcapping, ) - attention_weight = attention_weight.reshape(bsz, 1, -1) + attention_weight = output.view(bsz, seqlen, -1) attn_output = self.o_proj(attention_weight) self.start_pos += seqlen @@ -338,7 +295,6 @@ def forward( # about past key length past_key_value = [torch.zeros(1, 1, self.start_pos, 1)] - if HF_NEW_CACHE_FORMAT and self.is_hf_transformers: new_cache = DynamicCache() new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0) diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py index a4a02f2b..785aab1a 100644 --- a/awq/modules/fused/block.py +++ b/awq/modules/fused/block.py @@ -41,23 +41,17 @@ def __init__( def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): norm_out = self.norm_1(hidden_states) - attn_output, _, past_key_value = self.attn.forward( + attn_output, _, _ = self.attn.forward( hidden_states=norm_out, - past_key_value=past_key_value, - attention_mask=attention_mask, ) h = hidden_states.to(attn_output.device) + attn_output out = self.moe.forward(self.norm_2(h)) out = h + out - return out, None, past_key_value + return out class LlamaLikeBlock(nn.Module): @@ -106,7 +100,6 @@ def __init__( rope_theta=rope_theta, partial_rotary_factor=partial_rotary_factor, head_dim=head_dim, - use_sdpa=True, ).to(dev) self.norm_2 = norm_2.to(dev) self.mlp = mlp.to(dev) @@ -115,22 +108,16 @@ def __init__( def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): norm_out = self.norm_1(hidden_states) - attn_output, _, past_key_value = self.attn.forward( + attn_output, _, _ = self.attn.forward( hidden_states=norm_out, - past_key_value=past_key_value, - attention_mask=attention_mask, ) h = hidden_states.to(attn_output.device) + attn_output out = h + 
self.mlp.forward(self.norm_2(h)) - return out, None, past_key_value + return out class Gemma2LikeBlock(nn.Module): @@ -144,8 +131,8 @@ def __init__( mlp, norm_1, norm_2, - norm_3, - norm_4, + norm_3, + norm_4, dev, max_seq_len, rope_theta=10000, @@ -188,30 +175,24 @@ def __init__( def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): residual = hidden_states hidden_states = self.norm_1(hidden_states) - hidden_states, _, past_key_value = self.attn.forward( + hidden_states, _, _ = self.attn.forward( hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, ) hidden_states = self.norm_2(hidden_states) hidden_states = residual + hidden_states - + residual = hidden_states hidden_states = self.norm_3(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = self.norm_4(hidden_states) out = residual + hidden_states - return out, None, past_key_value + return out class CohereBlock(nn.Module): @@ -224,7 +205,6 @@ def __init__( o_proj, mlp, norm_1, - # norm_2, dev, max_seq_len, rope_theta=10000, @@ -256,29 +236,22 @@ def __init__( head_dim=head_dim, is_neox=False, ).to(dev) - # self.norm_2 = norm_2.to(dev) self.mlp = mlp.to(dev) self.device = dev def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): norm_out = self.norm_1(hidden_states) - attn_output, _, past_key_value = self.attn.forward( + attn_output, _, _ = self.attn.forward( hidden_states=norm_out, - past_key_value=past_key_value, - attention_mask=attention_mask, ) h = hidden_states.to(attn_output.device) + attn_output out = h + self.mlp.forward(norm_out) - return out, None, past_key_value + return out class MPTBlock(nn.Module): @@ -316,24 +289,15 @@ def __init__( def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): norm_out = self.norm_1(hidden_states) - attn_output, _, past_key_value = self.attn.forward( + attn_output, _, _ = self.attn.forward( hidden_states=norm_out, - past_key_value=past_key_value, - attention_mask=attention_mask, - position_ids=None, - output_attentions=False, - use_cache=True, ) h = hidden_states.to(attn_output.device) + attn_output out = h + self.ffn.forward(self.norm_2(h)) - return out, None, past_key_value + return out class FalconDecoderLayer(nn.Module): @@ -423,10 +387,6 @@ def _get_attention_shapes(self, n_heads, max_seq_len, head_dim): def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): if self.new_decoder_arch: layernorm_out = self.ln_attn(hidden_states) @@ -434,13 +394,8 @@ def forward( else: layernorm_out = self.input_layernorm(hidden_states) - attn_output, _, past_key_value = self.attn.forward( + attn_output, _, _ = self.attn.forward( hidden_states=layernorm_out, - past_key_value=past_key_value, - attention_mask=attention_mask, - position_ids=None, - output_attentions=False, - use_cache=True, ) h_attn = hidden_states.to(attn_output.device) + attn_output @@ -452,7 +407,7 @@ def forward( out = h_attn + h_mlp - return out, None, past_key_value + return out class Phi3Block(nn.Module): @@ -509,19 +464,13 @@ def __init__( def forward( self, hidden_states, - past_key_value, - attn_bias=None, - attention_mask=None, - is_causal=None, ): norm_out = self.norm_1(hidden_states) - attn_output, _, past_key_value = self.attn.forward( + attn_output, _, _ = self.attn.forward( hidden_states=norm_out, - past_key_value=past_key_value, - 
attention_mask=attention_mask, ) h = hidden_states.to(attn_output.device) + attn_output out = h + self.mlp.forward(self.norm_2(h)) - return out, None, past_key_value \ No newline at end of file + return out diff --git a/awq/modules/fused/cache.py b/awq/modules/fused/cache.py index 87943f59..301e4935 100644 --- a/awq/modules/fused/cache.py +++ b/awq/modules/fused/cache.py @@ -2,30 +2,39 @@ class WindowedCache: - def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device): + def __init__( + self, cache_batch_size, n_heads, n_kv_heads, head_dim, max_seq_len, device + ): """ The window size is the same as the max_seq_len. The window will automatically roll once max_seq_len is exceeded. """ - # [batch_size, n_kv_heads, max_seq_len, head_dim] - self.v = torch.zeros(cache_v_shape).to(device).half() - # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor] - self.k = torch.zeros(cache_k_shape).to(device).half() + size = ( + cache_batch_size, + max_seq_len, + n_kv_heads if n_kv_heads != 0 else n_heads, + head_dim, + ) + self.v = torch.zeros( + size, + device=device, + dtype=torch.float16, + ) + self.k = torch.zeros( + size, + device=device, + dtype=torch.float16, + ) self.max_seq_len = max_seq_len - def get_kv(self, batch_size, start_pos, seqlen, head_dim): + def get_kv(self, batch_size, start_pos, seqlen): """ Gets the key-value store in correct shapes. + NOTE: This function is a legacy function. It is only available to showcase + how to accurately retrieve the KV-cache but is not currently used. """ - xv = ( - self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous() - ) - xk = ( - self.k[:batch_size, :, :, : start_pos + seqlen, :] - .transpose(2, 3) - .contiguous() - ) - xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous() + xv = self.v[:batch_size, : start_pos + seqlen] + xk = self.k[:batch_size, : start_pos + seqlen] return xv, xk @@ -33,8 +42,8 @@ def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): """ Updates the values in the key-value store. 
""" - self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store - self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store + self.v[:batch_size, start_pos : start_pos + seqlen, :, :] = values_store + self.k[:batch_size, start_pos : start_pos + seqlen, :, :] = keys_store def roll_kv_n_steps(self, start_pos, n=100): """ diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py index 11e66ec5..b7912675 100644 --- a/awq/modules/fused/model.py +++ b/awq/modules/fused/model.py @@ -30,9 +30,6 @@ def __init__(self, vocab_size, blocks, embedding, norm): def forward( self, input_ids: torch.Tensor, - attn_bias=None, - attention_mask=None, - is_causal=None, *args, **kwargs, ): @@ -45,28 +42,15 @@ def forward( h = self.embedding(input_ids) - mask = fused_utils.prepare_attention_mask( - seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, past_key_value = layer( - h, None, attention_mask=mask, is_causal=is_causal - ) + h = h.to(layer.device) + h = layer(h) h = self.norm(h) return MoeModelOutputWithPast( last_hidden_state=h, - past_key_values=past_key_value, + past_key_values=None, hidden_states=(), attentions=(), router_logits=(), @@ -99,9 +83,6 @@ def layers(self): def forward( self, input_ids: torch.Tensor, - attn_bias=None, - attention_mask=None, - is_causal=None, *args, **kwargs, ): @@ -114,20 +95,10 @@ def forward( h = self.embedding(input_ids) - mask = fused_utils.prepare_attention_mask( - seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal) + h = h.to(layer.device) + h = layer(h) + h = self.norm(h) return BaseModelOutputWithPast( @@ -174,20 +145,10 @@ def forward( h = self.embedding(input_ids) - mask = fused_utils.prepare_attention_mask( - seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal) + h = h.to(layer.device) + h = layer(h) + h = self.norm(h) return BaseModelOutputWithPast( @@ -213,9 +174,6 @@ def __init__(self, vocab_size, blocks, wte, norm_f): def forward( self, input_ids, - attn_bias=None, - attention_mask=None, - is_causal=None, *args, **kwargs, ): @@ -228,27 +186,15 @@ def forward( h = self.wte(input_ids) - mask = fused_utils.prepare_attention_mask( - seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, past_key_value = layer( - h, None, attention_mask=mask, is_causal=is_causal - ) + h = h.to(layer.device) + h = layer(h) + h = self.norm_f(h) return BaseModelOutputWithPast( last_hidden_state=h, - past_key_values=past_key_value, + past_key_values=None, hidden_states=(), attentions=(), ) @@ -269,9 +215,6 @@ def __init__(self, vocab_size, blocks, word_embeddings, ln_f): def forward( self, input_ids, - attn_bias=None, - attention_mask=None, - is_causal=None, *args, **kwargs, ): @@ -284,31 +227,20 @@ def forward( h = self.word_embeddings(input_ids) - mask = fused_utils.prepare_attention_mask( - 
seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, past_key_value = layer( - h, None, attention_mask=mask, is_causal=is_causal - ) + h = h.to(layer.device) + h = layer(h) + h = self.ln_f(h) return BaseModelOutputWithPast( last_hidden_state=h, - past_key_values=past_key_value, + past_key_values=None, hidden_states=(), attentions=(), ) + class Phi3Model(nn.Module): """ Phi3LikeModel is intended to be reused across models that have @@ -335,9 +267,6 @@ def layers(self): def forward( self, input_ids: torch.Tensor, - attn_bias=None, - attention_mask=None, - is_causal=None, *args, **kwargs, ): @@ -350,22 +279,10 @@ def forward( h = self.embedding(input_ids) - mask = fused_utils.prepare_attention_mask( - seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, _ = layer( - h, None, attention_mask=mask, is_causal=is_causal - ) + h = h.to(layer.device) + h = layer(h) + h = self.norm(h) return BaseModelOutputWithPast( @@ -398,9 +315,6 @@ def layers(self): def forward( self, input_ids: torch.Tensor, - attn_bias=None, - attention_mask=None, - is_causal=None, *args, **kwargs, ): @@ -416,20 +330,10 @@ def forward( normalizer = torch.tensor(self.hidden_size**0.5, dtype=h.dtype) h = h * normalizer - mask = fused_utils.prepare_attention_mask( - seqlen=seqlen, - start_pos=self.blocks[0].attn.start_pos, - device=input_ids.device, - type_as=h, - ) - for layer in self.blocks: - h, mask = fused_utils.prepare_correct_devices( - layer, - h, - mask, - ) - h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal) + h = h.to(layer.device) + h = layer(h) + h = self.norm(h) return BaseModelOutputWithPast( diff --git a/awq/modules/fused/moe.py b/awq/modules/fused/moe.py index 70431609..fbdae92f 100644 --- a/awq/modules/fused/moe.py +++ b/awq/modules/fused/moe.py @@ -91,7 +91,6 @@ def apply_moe_weights( return torch.sum(out, dim=1) - def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int): """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. 
diff --git a/awq/modules/fused/norm.py b/awq/modules/fused/norm.py index 17710b08..51f8ca8a 100644 --- a/awq/modules/fused/norm.py +++ b/awq/modules/fused/norm.py @@ -31,6 +31,8 @@ def forward(self, x): "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels" ) output = torch.empty_like(x) - awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) + awq_ext.layernorm_forward_cuda( + x, self.weight, output, self.variance_epsilon + ) return output diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py index 577fc94d..8e35e06f 100644 --- a/awq/utils/fused_utils.py +++ b/awq/utils/fused_utils.py @@ -11,15 +11,6 @@ ) -def prepare_correct_devices(next_layer, hidden_states, mask): - hidden_states = hidden_states.to(next_layer.device) - - if mask is not None: - mask = mask.to(next_layer.device) - - return hidden_states, mask - - def prepare_cache(blocks, seqlen: int) -> int: for block in blocks: start_pos = block.attn.start_pos @@ -51,15 +42,6 @@ def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int): return input_ids, last_forward_num_tokens + num_new_tokens -def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor): - mask = None - if seqlen > 1: - mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device) - mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as) - - return mask - - def fuse_qkv(module, q_proj, k_proj, v_proj): bias = ( torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) @@ -181,28 +163,13 @@ def fuse_linears(linears, device, dim=1, operation=torch.cat): def get_attention_shapes( - attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim + attention_shapes, n_heads, n_kv_heads, head_dim ): if attention_shapes is not None: attention_shapes = attention_shapes elif n_kv_heads == 0: attention_shapes = { - # following fastertransformer definition - "cache_v": ( - cache_batch_size, - n_heads, - max_seq_len, - head_dim, - ), - # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": ( - cache_batch_size, - n_heads, - head_dim // 8, - max_seq_len, - 8, - ), "xqkv_view": (-1, n_heads, head_dim), "xq_slice": lambda xqkv: xqkv[:, :, 0], "xk_slice": lambda xqkv: xqkv[:, :, 1], @@ -218,21 +185,6 @@ def get_attention_shapes( else: attention_shapes = { - # following fastertransformer definition - "cache_v": ( - cache_batch_size, - n_kv_heads, - max_seq_len, - head_dim, - ), - # 8: pack 8 fp16 in FT, if fp32 then use 4 - "cache_k": ( - cache_batch_size, - n_kv_heads, - head_dim // 8, - max_seq_len, - 8, - ), "xqkv_view": (n_heads + n_kv_heads * 2, head_dim), "xq_slice": lambda xqkv: xqkv[:, :, 0:n_heads], "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)], diff --git a/setup.py b/setup.py index b86da218..b904a440 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ "eval": ["lm_eval==0.4.1", "tabulate", "protobuf", "evaluate", "scipy"], "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"], "cpu": ["intel-extension-for-pytorch>=2.4.0"], - "kernels": ["autoawq-kernels"], + "kernels": ["autoawq-kernels", "flash-attn>=2.2.0"], }, **common_setup_kwargs, )
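
Illustrative note (not part of the patch): a minimal standalone sketch of how the flash-attn calls adopted above are typically driven with the new [batch, seq_len, n_kv_heads, head_dim] cache layout. All sizes and tensor names below are invented for illustration; it assumes flash-attn is installed and a CUDA device is available, since flash-attn only accepts fp16/bf16 CUDA tensors.

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

batch, n_heads, n_kv_heads, head_dim, max_seq_len = 1, 8, 2, 64, 128
device, dtype = "cuda", torch.float16

# KV cache laid out as [batch, max_seq_len, n_kv_heads, head_dim], matching the new WindowedCache.
k_cache = torch.zeros(batch, max_seq_len, n_kv_heads, head_dim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# Prefill: full causal attention over the prompt (GQA: n_heads must be divisible by n_kv_heads).
prompt_len = 16
q = torch.randn(batch, prompt_len, n_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, prompt_len, n_kv_heads, head_dim, device=device, dtype=dtype)
v = torch.randn_like(k)
out = flash_attn_func(q, k, v, causal=True)  # -> (batch, prompt_len, n_heads, head_dim)

# Store the prompt keys/values in the cache, as WindowedCache.update_kv does in the patch.
k_cache[:, :prompt_len] = k
v_cache[:, :prompt_len] = v

# Decode step: one new token. cache_seqlens is the number of tokens already cached;
# the kernel appends the new k/v at that offset and attends over the whole prefix.
q_new = torch.randn(batch, 1, n_heads, head_dim, device=device, dtype=dtype)
k_new = torch.randn(batch, 1, n_kv_heads, head_dim, device=device, dtype=dtype)
v_new = torch.randn_like(k_new)
cache_seqlens = torch.full((batch,), prompt_len, dtype=torch.int32, device=device)
out_step = flash_attn_with_kvcache(
    q_new, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
)  # -> (batch, 1, n_heads, head_dim)

Design note: this sketch lets flash_attn_with_kvcache append the new keys/values itself, which is the usage pattern described in the flash-attn docstring; the patch instead writes them into the cache via update_kv before the kernel call and sizes cache_seqlens accordingly.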