Remove past_key_value (save 2GB VRAM)
casper-hansen committed Oct 5, 2023
1 parent eccb8f9 commit 40e6952
Showing 1 changed file with 2 additions and 6 deletions.
awq/modules/fused/attn.py: 8 changes (2 additions, 6 deletions)
@@ -206,7 +206,6 @@ def forward(
                 keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                 values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
 
-            past_key_value = (xk, xv) if use_cache else None
             xq = xq.transpose(1, 2)
             keys = keys.transpose(1, 2)
             values = values.transpose(1, 2)
@@ -222,14 +221,10 @@ def forward(
             output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
             attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
-            # xq = xq[:, 0, :, :]
-            # xk = xk[:, 0, :, :]
-            # xv = xv[:, 0, :, :]
             xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
             xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
             xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
 
-            past_key_value = (xk, xv) if use_cache else None
             attention_weight = ft_inference_engine.single_query_attention(
                 xq, # query
                 xk, # key
@@ -252,4 +247,5 @@ def forward(
         else:
             self.start_pos = 0
 
-        return attn_output, attention_weight, past_key_value
+        # past_key_value is replaced with cache_v, cache_k, returning None
+        return attn_output, attention_weight, None
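
For context on the commit title, here is a rough, hedged sketch of where the savings come from. The fused attention module keeps the running keys and values in its own preallocated cache_k/cache_v buffers (see the new comment in the diff), so also returning a past_key_value tuple of (xk, xv) kept a second copy of every layer's cached keys and values alive on the GPU. The shapes below (Llama-7B-like: 32 layers, 32 KV heads, head_dim 128, fp16, batch 1, 4096-token context) are illustrative assumptions, not measurements from this commit:

n_layers = 32          # transformer blocks (assumed, Llama-7B-like)
n_kv_heads = 32        # KV heads; no grouped-query attention assumed
head_dim = 128         # per-head dimension
seq_len = 4096         # cached context length
batch_size = 1
bytes_per_elem = 2     # fp16

# xk and xv each hold roughly (batch, seq_len, n_kv_heads, head_dim) elements.
per_tensor = batch_size * seq_len * n_kv_heads * head_dim * bytes_per_elem
per_layer = 2 * per_tensor    # keys + values
total = n_layers * per_layer  # one duplicate tuple per layer

print(f"~{total / 1024**3:.1f} GiB held by returned past_key_value tuples")

At those assumed shapes the duplicated tuples come to roughly 2 GiB, which lines up with the savings quoted in the commit message; returning None lets the caller drop them and rely on the module's internal cache instead.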
