diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index 332b7163..73bedd26 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -219,7 +219,7 @@ def forward( xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"]) past_key_value = (xk, xv) if use_cache else None - attention_weight = awq_inference_engine.single_query_attention( + attention_weight = ft_inference_engine.single_query_attention( xq, # query xk, # key xv, # value