Commit d08b78b
Properly initializing the new field in the attn metadata (#337)
1 parent 399016d · commit d08b78b

17 files changed: +38 −9 lines changed

tests/kernels/utils.py

Lines changed: 2 additions & 0 deletions
@@ -914,6 +914,7 @@ def make_test_metadata(
     num_prefills=num_prefills,
     slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     num_prefill_tokens=num_prefill_tokens,
     num_decode_tokens=num_decode_tokens,
     seq_lens=seq_lens,
@@ -963,6 +964,7 @@ def make_test_metadata(
     num_prefills=num_prefills,
     slot_mapping=kv_mmap.slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     num_prefill_tokens=num_prefill_tokens,
     num_decode_tokens=num_decode_tokens,
     seq_lens=seq_lens,

tests/worker/test_model_input.py

Lines changed: 3 additions & 0 deletions
@@ -74,6 +74,7 @@ def test_model_runner_input():
     num_decode_tokens=3,
     slot_mapping=torch.zeros(1),
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
 )
 model_input = ModelInputForGPUWithSamplingMetadata(
     input_tokens=torch.ones(10),
@@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
     num_decode_tokens=3,
     slot_mapping=torch.zeros(1),
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
 )
 model_input = ModelInputForGPUWithPoolingMetadata(
     input_tokens=torch.ones(10),
@@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
     num_decode_tokens=3,
     slot_mapping=torch.zeros(1),
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
 )
 frozen_model_input = ModelInputForGPUWithSamplingMetadata(
     input_tokens=torch.ones(10),

vllm/attention/backends/abstract.py

Lines changed: 2 additions & 3 deletions
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, fields
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                     Tuple, Type, TypeVar)

@@ -126,8 +126,7 @@ class AttentionMetadata:

     # Enable/disable KV scales calculation. This is so that we can disable the
     # calculation until after prefill and cuda graph capture.
-    enable_kv_scales_calculation: bool = field(init=False,
-                                               default_factory=lambda: True)
+    enable_kv_scales_calculation: bool

     @property
     @abstractmethod
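A note on the dataclass change above: a field declared with init=False is excluded from the generated __init__, so no call site could pass enable_kv_scales_calculation explicitly and every metadata object silently defaulted to True. Declared as a plain required field, every constructor site must now choose a value, and a missed site fails loudly at construction time. A minimal standalone sketch of the two behaviors (toy classes, not the vLLM ones):

from dataclasses import dataclass, field

@dataclass
class OldMeta:
    num_prefills: int
    # init=False: the generated __init__ does not accept this argument.
    enable_kv_scales_calculation: bool = field(init=False,
                                               default_factory=lambda: True)

@dataclass
class NewMeta:
    num_prefills: int
    # Plain required field: every call site must pass a value explicitly.
    enable_kv_scales_calculation: bool

try:
    OldMeta(num_prefills=1, enable_kv_scales_calculation=False)
except TypeError as e:
    print(e)  # unexpected keyword argument 'enable_kv_scales_calculation'

try:
    NewMeta(num_prefills=1)
except TypeError as e:
    print(e)  # missing required argument: 'enable_kv_scales_calculation'

print(NewMeta(num_prefills=1, enable_kv_scales_calculation=True))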

vllm/attention/backends/blocksparse_attn.py

Lines changed: 2 additions & 0 deletions
@@ -222,6 +222,7 @@ def prefill_metadata(
     slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
     multi_modal_placeholder_index_maps=self.
     multi_modal_placeholder_index_maps,
+    enable_kv_scales_calculation=self.enable_kv_scales_calculation,
     seq_lens=self.seq_lens[:self.num_prefills],
     seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
     max_query_len=self.max_query_len,
@@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
     num_decode_tokens=self.num_decode_tokens,
     slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     seq_lens=None,
     seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
     max_query_len=None,

vllm/attention/backends/flash_attn.py

Lines changed: 3 additions & 0 deletions
@@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=self.
     multi_modal_placeholder_index_maps,
+    enable_kv_scales_calculation=self.enable_kv_scales_calculation,
     seq_lens=seq_lens,
     seq_lens_tensor=seq_lens_tensor,
     max_query_len=self.max_query_len,
@@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
     num_decode_tokens=self.num_decode_tokens,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     seq_lens=None,
     seq_lens_tensor=seq_lens_tensor,
     max_decode_query_len=self.max_decode_query_len,
@@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
     num_decode_tokens=num_decode_tokens,
     seq_lens=seq_lens,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=True,
     seq_lens_tensor=seq_lens_tensor,
     max_query_len=max_query_len,
     max_decode_query_len=max_decode_query_len,

vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 0 deletions
@@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch(
     num_prefills=0,
     slot_mapping=self._graph_slot_mapping[:batch_size],
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     num_prefill_tokens=0,
     num_decode_tokens=batch_size,
     max_prefill_seq_len=0,
@@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
     num_prefills=self.num_prefills,
     slot_mapping=slot_mapping_tensor,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=False,
     num_prefill_tokens=self.num_prefill_tokens,
     num_decode_tokens=num_decode_tokens,
     max_prefill_seq_len=max_prefill_seq_len,

vllm/attention/backends/placeholder_attn.py

Lines changed: 3 additions & 0 deletions
@@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=self.
     multi_modal_placeholder_index_maps,
+    enable_kv_scales_calculation=self.enable_kv_scales_calculation,
     seq_lens=self.seq_lens[:self.num_prefills],
     seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
     max_decode_query_len=0,
@@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
     num_decode_tokens=self.num_decode_tokens,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     seq_lens=None,
     seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
     max_decode_query_len=self.max_decode_query_len,
@@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
     num_prefills=self.num_prefills,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=True,
     num_prefill_tokens=self.num_prefill_tokens,
     num_decode_tokens=num_decode_tokens,
     seq_lens=seq_lens,

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
     slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
     multi_modal_placeholder_index_maps=self.
     multi_modal_placeholder_index_maps,
+    enable_kv_scales_calculation=self.enable_kv_scales_calculation,
     seq_lens=self.seq_lens[:self.num_prefills],
     seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
     max_query_len=self.max_query_len,
@@ -202,6 +203,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
     num_decode_tokens=self.num_decode_tokens,
     slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     seq_lens=None,
     seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
     max_query_len=None,

vllm/attention/backends/torch_sdpa.py

Lines changed: 1 addition & 0 deletions
@@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
     prefill_block_tables=prefill_block_tables,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=False,
 )

 return attn_metadata

vllm/attention/backends/utils.py

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
     num_prefills=self.num_prefills,
     slot_mapping=slot_mapping_tensor,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=True,
     num_prefill_tokens=self.num_prefill_tokens,
     num_decode_tokens=num_decode_tokens,
     seq_lens=seq_lens,
@@ -326,6 +327,7 @@ def graph_capture_get_metadata_for_batch(
     num_decode_tokens=batch_size,
     slot_mapping=self._graph_slot_mapping[:batch_size],
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     seq_lens=None,
     seq_lens_tensor=self._graph_seq_lens[:batch_size],
     max_query_len=1,

vllm/attention/backends/xformers.py

Lines changed: 2 additions & 0 deletions
@@ -217,6 +217,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=self.
     multi_modal_placeholder_index_maps,
+    enable_kv_scales_calculation=self.enable_kv_scales_calculation,
     seq_lens=seq_lens,
     seq_lens_tensor=seq_lens_tensor,
     max_query_len=self.max_query_len,
@@ -261,6 +262,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
     num_decode_tokens=self.num_decode_tokens,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=True,
     seq_lens_tensor=seq_lens_tensor,
     max_prefill_seq_len=0,
     max_decode_seq_len=self.max_decode_seq_len,

vllm/worker/hpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -892,7 +892,8 @@ def _prepare_prompt(
     num_decode_tokens=0,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=
-    None  # FIXME(kzawora): mutli-modality will not work here
+    None,  # FIXME(kzawora): mutli-modality will not work here
+    enable_kv_scales_calculation=False,
 )
 multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

@@ -1046,7 +1047,9 @@ def _prepare_decode(
     num_prefill_tokens=0,
     num_decode_tokens=num_decode_tokens,
     slot_mapping=slot_mapping,
-    multi_modal_placeholder_index_maps=None)
+    multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
+)
 return PrepareDecodeMetadata(input_tokens=input_tokens,
                              input_positions=input_positions,
                              attn_metadata=attn_metadata,

vllm/worker/model_runner.py

Lines changed: 0 additions & 2 deletions
@@ -174,8 +174,6 @@ def from_broadcasted_tensor_dict(
 if attn_backend is not None:
     tensor_dict = _init_attn_metadata_from_tensor_dict(
         attn_backend, tensor_dict)
-    if "enable_kv_scales_calculation" in tensor_dict:
-        tensor_dict.pop("enable_kv_scales_calculation")
 return cls(**tensor_dict)

vllm/worker/model_runner_base.py

Lines changed: 1 addition & 2 deletions
@@ -47,8 +47,7 @@ def _init_attn_metadata_from_tensor_dict(
 # Extract the fields used to create AttentionMetadata.
 valid_attn_kwargs = {}
 for field in dataclasses.fields(attn_backend.get_metadata_cls()):
-    if field.name in tensor_dict and field.name != \
-            'enable_kv_scales_calculation':
+    if field.name in tensor_dict:
         valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

 attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
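With the exclusion removed, _init_attn_metadata_from_tensor_dict treats enable_kv_scales_calculation like any other metadata field: if it was broadcast in the tensor dict, it is popped and forwarded to make_metadata. This is also why the explicit pop in model_runner.py above became unnecessary. A simplified standalone sketch of the extraction loop (toy types and names, not the vLLM signatures):

import dataclasses
from typing import Any, Dict

@dataclasses.dataclass
class ToyMetadata:
    num_prefills: int
    enable_kv_scales_calculation: bool

def extract_metadata_kwargs(metadata_cls: type,
                            tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Pop every dataclass field present in the broadcast dict;
    # no field is special-cased anymore.
    return {
        f.name: tensor_dict.pop(f.name)
        for f in dataclasses.fields(metadata_cls)
        if f.name in tensor_dict
    }

tensor_dict = {"num_prefills": 2,
               "enable_kv_scales_calculation": True,
               "input_tokens": [1, 2, 3]}  # non-metadata entries stay behind
kwargs = extract_metadata_kwargs(ToyMetadata, tensor_dict)
print(ToyMetadata(**kwargs))  # ToyMetadata(num_prefills=2, enable_kv_scales_calculation=True)
print(tensor_dict)            # {'input_tokens': [1, 2, 3]}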

vllm/worker/openvino_model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -278,6 +278,7 @@ def _prepare_model_input(
     block_indices_begins=block_indices_begins_tensor,
     max_context_len=max_context_len_tensor,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=False,
 )

 multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

vllm/worker/tpu_model_runner.py

Lines changed: 5 additions & 0 deletions
@@ -184,6 +184,7 @@ def _dummy_run(
     num_decode_tokens=0,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     block_tables=None,
     context_lens=None,
     effective_query_lens=None,
@@ -202,6 +203,7 @@ def _dummy_run(
     num_decode_tokens=0,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     block_tables=block_tables,
     context_lens=context_lens,
     effective_query_lens=effective_query_lens,
@@ -233,6 +235,7 @@ def _dummy_run(
     num_decode_tokens=batch_size * seq_len,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     block_tables=block_tables,
     context_lens=context_lens,
 )
@@ -418,6 +421,7 @@ def _prepare_prompt(
     num_decode_tokens=0,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     block_tables=block_tables,
     context_lens=context_lens,
     effective_query_lens=prompt_lens,
@@ -489,6 +493,7 @@ def _prepare_decode(
     num_decode_tokens=batch_size,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     block_tables=block_tables,
     context_lens=context_lens,
 )

vllm/worker/xpu_model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -257,6 +257,7 @@ def _prepare_prompt(
     is_prompt=True,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=placeholder_index_maps,
+    enable_kv_scales_calculation=False,
     seq_lens=seq_lens,
     seqlen_q=seqlen_q,
     max_seqlen=max_seqlen,
@@ -341,6 +342,7 @@ def _prepare_decode(
     is_prompt=False,
     slot_mapping=slot_mapping,
     multi_modal_placeholder_index_maps=None,
+    enable_kv_scales_calculation=False,
     seq_lens=seq_lens,
     seqlen_q=torch.tensor([]),
     max_seqlen=0,
