From 8a7cc254a064b8d42bf4de7a9c3f29552240dfd9 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 15 May 2024 11:52:45 +0900 Subject: [PATCH 1/4] Revert "[Kernel] Use flash-attn for decoding (#3648)" (#4820) Lora 3 & 4 test seems to have illegal memory access failure after this commit; [2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered
Exmaple: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241 This reverts commit 1356df5. FILL IN THE PR DESCRIPTION HERE FIX #xxxx (link existing issues this PR will resolve) --- tests/kernels/test_flash_attn.py | 209 -------------------------- tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 128 +++++++--------- vllm/attention/selector.py | 14 -- vllm/worker/model_runner.py | 15 +- 6 files changed, 65 insertions(+), 313 deletions(-) delete mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py deleted file mode 100644 index 89bdacc67fbc4..0000000000000 --- a/tests/kernels/test_flash_attn.py +++ /dev/null @@ -1,209 +0,0 @@ -from typing import List, Optional, Tuple - -import pytest -import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache - -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] - - -def ref_paged_attn( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], - block_tables: torch.Tensor, - scale: float, - sliding_window: Optional[int] = None, -) -> torch.Tensor: - num_seqs = len(query_lens) - block_tables = block_tables.cpu().numpy() - _, block_size, num_kv_heads, head_size = key_cache.shape - - outputs = [] - start_idx = 0 - for i in range(num_seqs): - query_len = query_lens[i] - kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] - q *= scale - - num_kv_blocks = (kv_len + block_size - 1) // block_size - block_indices = block_tables[i, :num_kv_blocks] - - k = key_cache[block_indices].view(-1, num_kv_heads, head_size) - k = k[:kv_len] - v = value_cache[block_indices].view(-1, num_kv_heads, head_size) - v = v[:kv_len] - - if q.shape[1] != k.shape[1]: - k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) - v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k).float() - empty_mask = torch.ones(query_len, kv_len) - mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() - if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() - mask |= sliding_window_mask - attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1).to(v.dtype) - out = torch.einsum("hqk,khd->qhd", attn, v) - - outputs.append(out) - start_idx += query_len - - return torch.cat(outputs, dim=0) - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_flash_attn_with_paged_kv( - kv_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - ).squeeze(1) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_varlen_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - sliding_window: Optional[int], - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_query_len = max(query_lens) - max_kv_len = max(kv_lens) - window_size = ((sliding_window, - sliding_window) if sliding_window is not None else - (-1, -1)) - scale = head_size**-0.5 - - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - # Normalize the scale of the key and value caches to mitigate - # numerical instability. - key_cache /= head_size**0.5 - value_cache /= head_size**0.5 - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - sliding_window=sliding_window, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 10e7c64e34e75..c02204f16ac68 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - # "mosaicml/mpt-7b", # Broken + "mosaicml/mpt-7b", # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..e87a1783a83f1 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 11ecb2792ea9d..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,16 +1,20 @@ -"""Attention layer with FlashAttention.""" +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. +""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func -from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) - -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) class FlashAttentionBackend(AttentionBackend): @@ -34,9 +38,8 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -44,26 +47,19 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + PagedAttention.copy_blocks(kv_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -109,14 +105,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - class FlashAttentionImpl(AttentionImpl): """ @@ -168,15 +156,11 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - if head_size not in _SUPPORTED_HEAD_SIZES: + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") def forward( self, @@ -187,20 +171,17 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with FlashAttention and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -208,20 +189,16 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache = kv_cache[0] - value_cache = kv_cache[1] + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - ) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -241,8 +218,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -263,32 +239,38 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.subquery_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=prefill_meta.max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], ) - if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ).squeeze(1) + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5140c3cc86a31..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,20 +93,6 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - - if sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - return _Backend.XFORMERS - try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3f7e87c1de48c..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -266,27 +266,20 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums + prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) else: # The first prefill. - block_table = [] + prefix_block_tables.append([]) else: - block_table = [] + prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 - prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) From 65bf2ac165734fb6339210c4b2b8ce68d2391b77 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 15 May 2024 14:00:10 +0900 Subject: [PATCH 2/4] [Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681) This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend. It also refactors subquery_start_loc which was not refactored in the previous PR --- tests/worker/test_model_runner.py | 123 ++- vllm/attention/__init__.py | 5 +- vllm/attention/backends/abstract.py | 68 +- vllm/attention/backends/flash_attn.py | 95 ++- vllm/attention/backends/flashinfer.py | 38 +- vllm/attention/backends/rocm_flash_attn.py | 98 ++- vllm/attention/backends/torch_sdpa.py | 28 +- vllm/attention/backends/xformers.py | 92 ++- vllm/attention/layer.py | 5 +- vllm/attention/ops/paged_attn.py | 10 +- vllm/engine/arg_utils.py | 5 + .../layers/rejection_sampler.py | 1 + vllm/sequence.py | 3 +- vllm/spec_decode/batch_expansion.py | 23 +- vllm/spec_decode/multi_step_worker.py | 1 + vllm/worker/cpu_model_runner.py | 10 +- vllm/worker/embedding_model_runner.py | 130 +-- vllm/worker/model_runner.py | 772 ++++++++---------- 18 files changed, 777 insertions(+), 730 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index c2d1c5769619b..92de545acd53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is True + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 assert torch.allclose( attn_metadata.seq_lens_tensor, torch.tensor(seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 # Test subquery start locs. start_idx = 0 @@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size): start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( - attn_metadata.subquery_start_loc, + attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) # Test seq start locs. Note that for normal prefill it is - # equivalent to subquery_start_loc. + # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] for seq_len in seq_lens: @@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, ) - seq_lens = [] + context_lens = [] seq_group_metadata_list = [] + # Assume each seq group finishes prefill. for i in range(batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + context_lens.append(context_len) + seq_data = list(range(context_len)) seq_data = SequenceData(seq_data) + seq_data.update_num_computed_tokens(context_len) + # Append one token ID since prefill is finished. + seq_data.append_token_id(1, 0) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, model_input.input_positions, + model_input.attn_metadata, model_input.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is False - assert attn_metadata.seq_lens is None - assert attn_metadata.subquery_start_loc is None - assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + seq_lens = [context_len + 1 for context_len in context_lens] + # seq_lens are padded to expected_bs + for _ in range(expected_bs - len(seq_lens)): + seq_lens.append(1) + assert attn_metadata.seq_lens == seq_lens + start_idx = 0 + start_loc = [start_idx] + for _ in context_lens: + # decode has only 1 token for query. + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor(context_lens, dtype=torch.int, device=device)) + assert attn_metadata.max_decode_seq_len == max(seq_lens) assert torch.allclose( attn_metadata.seq_lens_tensor[:len(seq_lens)], torch.tensor(seq_lens, dtype=torch.int, device=device)) @@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size): # It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) - # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for seq_len in seq_lens: + for _ in context_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, - query_lens=seq_lens, + # query lens is all 1 for decode. + query_lens=[1 for _ in range(len(context_lens))], device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -220,15 +257,27 @@ def test_empty_seq_group(): enforce_eager=False, ) seq_group_metadata_list = [] - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None @@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(context_len)) seq_data = SequenceData(prompt_toks) + seq_data.append_token_id(1, 0) + seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) assert attn_metadata.num_prefills == prefill_batch_size - if enforce_eager: - assert attn_metadata.num_decode_tokens == decode_batch_size - else: - assert attn_metadata.num_decode_tokens == _get_graph_batch_size( - decode_batch_size) + assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. - prefill_meta = model_runner._prepare_prompt( - prefill_metadata_list).attn_metadata - decode_meta = model_runner._prepare_decode( - decode_metadata_list).attn_metadata + attn_metadata = model_runner._prepare_model_input( + seq_group_metadata_list).attn_metadata - for attr_expected, attr_actual in zip(vars(prefill_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), vars(prefill_meta_actual)): assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(decode_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), vars(decode_meta_actual)): assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 088f48def7668..f6bce9a187c64 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,6 +1,5 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,6 +7,6 @@ "Attention", "AttentionBackend", "AttentionMetadata", - "AttentionMetadataPerStage", + "Attention", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 98d70fcab1a18..94ab64de30a94 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -21,7 +21,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": + def make_metadata(*args, **kwargs) -> "AttentionMetadata": raise NotImplementedError @staticmethod @@ -53,8 +53,34 @@ def copy_blocks( @dataclass -class AttentionMetadataPerStage: - """Attention metadata for a specific stage. I.e., prefill or decode.""" +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -70,40 +96,10 @@ def asdict_zerocopy(self, } -T = TypeVar("T", bound=AttentionMetadataPerStage) - - -@dataclass -class AttentionMetadata(Generic[T]): - """Attention metadata for prefill and decode batched together.""" - # Total number of prefill requests. - num_prefills: int - # Number of prefill tokens. - num_prefill_tokens: int - # Number of decode tokens. Note that it is equivalent to the number of - # decode requests. - num_decode_tokens: int - # The attention metadata for prefill requests in a batch. - # None if there's no prefill requests in a batch. - prefill_metadata: Optional[T] - # The attention metadata for decode requests in a batch. - # None if there's no decode requests in a batch. - decode_metadata: Optional[T] - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - - def __post_init__(self): - if self.num_prefill_tokens > 0: - assert self.num_prefills > 0 - assert self.prefill_metadata is not None - if self.num_decode_tokens > 0: - assert self.decode_metadata is not None +T = TypeVar("T", bound=AttentionMetadata) -class AttentionImpl(ABC): +class AttentionImpl(ABC, Generic[T]): @abstractmethod def __init__( @@ -125,7 +121,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: T, kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f59715bd76ede..5d1f65819ed4e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,8 +11,7 @@ from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -58,8 +57,7 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + class FlashAttentionImpl(AttentionImpl): """ @@ -168,7 +231,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[FlashAttentionMetadata], + attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -228,8 +291,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -249,7 +312,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -264,7 +327,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 92d0fe0487516..5f9fd586fb70e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -8,8 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) class FlashInferBackend(AttentionBackend): @@ -56,9 +55,10 @@ def get_supported_head_sizes() -> List[int]: @dataclass -class FlashInferMetadata(AttentionMetadataPerStage): - - is_prompt: bool +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int use_cuda_graph: bool = False @@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage): # Metadata for the prefill stage since we still # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None - max_seq_len: Optional[int] = None block_tables: Optional[torch.Tensor] = None # Metadata for the decode stage @@ -113,7 +112,8 @@ def __post_init__(self): # When using flashinfer, we are also creating the FlashInferMetadata, # which will also call post_init by default, here we want to skip the # post_init if it's the prefill phase. - if not self.is_prompt: + if self.num_prefills == 0: + assert self.num_decode_tokens > 0 self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") self.decode_wrapper.begin_forward( @@ -138,6 +138,24 @@ def asdict_zerocopy(self, skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + class FlashInferImpl(AttentionImpl): @@ -172,7 +190,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], + attn_metadata: FlashInferMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: assert kv_scale == 1.0 @@ -208,8 +226,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 539585b46c7aa..1a94dc3596358 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,8 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -56,8 +55,7 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata class ROCmFlashAttentionImpl(AttentionImpl): @@ -198,7 +260,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], + attn_metadata: ROCmFlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -266,8 +328,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_seq_len, - prefill_meta.max_seq_len, + prefill_meta.max_prefill_seq_len, + prefill_meta.max_prefill_seq_len, True, self.scale, ) @@ -290,8 +352,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, ) @@ -308,7 +370,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -324,7 +386,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2dd72a00c6e30..a3f72b9c94566 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,8 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -54,8 +53,7 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, - AttentionMetadataPerStage): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts @@ -72,8 +70,26 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self -class TorchSDPABackendImpl(AttentionImpl): + return None + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( self, @@ -200,7 +216,7 @@ def forward( value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, - attn_metadata.max_seq_len, + attn_metadata.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cb2028553461f..fc46af054de4f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,8 +9,7 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -59,7 +58,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is @@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -114,8 +116,68 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None - -class XFormersImpl(AttentionImpl): + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -176,7 +238,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[XFormersMetadata], + attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -244,7 +306,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -261,7 +323,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 8a872dba8c877..126692d8c9b40 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import (AttentionMetadata, - AttentionMetadataPerStage) +from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -57,7 +56,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[AttentionMetadataPerStage], + attn_metadata: AttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 3c010b67b3120..30feaa4da254d 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -16,8 +16,8 @@ class PagedAttentionMetadata: # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -166,7 +166,7 @@ def forward_prefix( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, + query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, max_query_len: int, @@ -182,8 +182,8 @@ def forward_prefix( key_cache, value_cache, block_tables, - # subquery_start_loc is (batch_size + 1,) - subquery_start_loc[:-1], + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], seq_lens_tensor, context_lens, max_query_len, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 195d9e1b33e3c..bd44c2470182b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -618,6 +618,11 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + if (model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled): + raise ValueError( + "Chunked prefill is not supported with sliding window.") + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index b5f1e55d0e839..1f2ab7e2870ca 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -122,6 +122,7 @@ def forward( draft_token_ids, bonus_token_ids, ) + return output_token_ids def _batch_modified_rejection_sampling( diff --git a/vllm/sequence.py b/vllm/sequence.py index 12e930c27173e..aa759448d82b1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -654,8 +654,9 @@ def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> Optional[int]: + def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" + assert self._token_chunk_size is not None return self._token_chunk_size diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index d5fd96907ddd7..7792f3a3425cc 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -293,21 +293,30 @@ def _create_single_target_seq_group_metadata( prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, - seq_data={ - target_seq_id: - SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ), - }, + seq_data=new_seq_data_dict, sampling_params=seq_group_metadata.sampling_params, block_tables={ target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, + token_chunk_size=1, ) def _split_scoring_output( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 20098ebaeea32..b5a805278d273 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -114,6 +114,7 @@ def _append_new_tokens( token_logprob = seq_output.logprobs[token_id] seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0a0b0d70cfe21..bc88f2c5bed6c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -159,12 +159,10 @@ def _prepare_prompt( is_prompt=True, seq_lens=seq_lens, seq_lens_tensor=None, - max_seq_len=None, + max_decode_seq_len=None, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, - prefill_metadata=None, - decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, ) @@ -213,7 +211,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_seq_len = max(seq_lens) + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -243,12 +241,10 @@ def _prepare_decode( slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_seq_len=max_seq_len, + max_decode_seq_len=max_decode_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, - prefill_metadata=None, - decode_metadata=None, block_tables=block_tables, ) return ( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d04bebbdc31b6..91f30978ead87 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import BatchType, ModelRunner +from vllm.worker.model_runner import ModelRunner logger = init_logger(__name__) @@ -88,85 +88,24 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, - prompt_lens, - subquery_lens, - lora_index_mapping, - lora_prompt_mapping, + attn_metadata, + seq_lens, + _, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) - + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - prompt_lens) - - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(prompt_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE + seq_lens) metadata_dict = { "input_tokens": input_tokens, @@ -178,65 +117,26 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - + attn_metadata = None pooling_metadata = PoolingMetadata(seq_groups=None, seq_data=None, prompt_lens=None) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, pooling_metadata, lora_requests, lora_mapping, multi_modal_input) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e1991717b13..dcdd4b962454e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,13 +1,11 @@ import time -from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn -from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend) +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -37,66 +35,38 @@ ] -class PreparePromptMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadataPerStage] +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[AttentionMetadata] seq_lens: List[int] query_lens: List[int] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] + lora_mapping: Optional[LoRAMapping] lora_requests: Set[LoRARequest] multi_modal_input: Optional[torch.Tensor] - slot_mapping: List[int] + slot_mapping: torch.Tensor + num_prefill_tokens: int + num_decode_tokens: int + num_prefills: int @classmethod - def empty(cls): - return PreparePromptMetadata( - input_tokens=[], - input_positions=[], + def empty(cls, device): + return ModelInput( + input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), attn_metadata=None, seq_lens=[], query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], + lora_mapping=None, lora_requests=set(), multi_modal_input=None, - slot_mapping=[], - ) - - -class PrepareDecodeMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadata] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] - lora_requests: Set[LoRARequest] - slot_mapping: List[int] - - @classmethod - def empty(cls): - return PrepareDecodeMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, ) -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - class ModelRunner: def __init__( @@ -216,10 +186,22 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_prompt( + def _prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PreparePromptMetadata: + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -228,212 +210,16 @@ def _prepare_prompt( lora_requests: Set[LoRARequest] = set() seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] - - if len(seq_group_metadata_list) == 0: - return PreparePromptMetadata.empty() - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - context_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - seq_len = min(seq_data.get_len(), context_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] - seq_lens.append(seq_len) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - context_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) - elif self.scheduler_config.chunked_prefill_enabled: - if seq_group_metadata.block_tables is not None: - # Prefill has chunked before. - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - else: - # The first prefill. - prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert context_len == 0 - - # actual prompt lens - context_lens.append(context_len) - query_lens.append(seq_len - context_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend([lora_id] * ( - seq_len - context_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - assert context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - max_query_len = max(query_lens) - max_seq_len = max(seq_lens) - assert max_query_len > 0 - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=subquery_start_loc.dtype, - out=subquery_start_loc[1:]) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - if self.attn_backend.get_name() == "flashinfer": - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - use_cuda_graph=False, - seq_start_loc=seq_start_loc, - max_seq_len=max_seq_len, - block_tables=block_tables) - else: - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - - return PreparePromptMetadata( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, - lora_requests=lora_requests, - multi_modal_input=multi_modal_input, - slot_mapping=slot_mapping, - ) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PrepareDecodeMetadata: - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] block_tables: List[List[int]] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() + multi_modal_input_list: List[torch.Tensor] = [] + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout @@ -454,60 +240,186 @@ def _prepare_decode( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return PrepareDecodeMetadata.empty() + return ModelInput.empty(self.device) for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) + is_prompt = seq_group_metadata.is_prompt for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert (not self.scheduler_config. + chunked_prefill_enabled) + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append(position) + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + seq_len = min(seq_len, self.sliding_window) + context_len = seq_len - 1 - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) seq_lens.append(seq_len) + context_lens.append(context_len) + query_len = seq_len - context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (seq_len - + context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) + if is_prompt: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) - last_page_len = seq_data.get_len() % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) - # vLLM uses cuda graph only for decoding requests. + # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. - # For decoding requests, batch_size == input_tokens. - batch_size = len(input_tokens) - max_seq_len = max(seq_lens) - use_captured_graph = (not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = ( + decode_only and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -519,18 +431,9 @@ def _prepare_decode( block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + num_decode_tokens = batch_size if use_captured_graph: - # When using cuda-graph all these tensors should be - # padded. - assert seq_lens_tensor.shape[0] == len(input_tokens) - assert seq_lens_tensor.shape[0] == len(input_positions) - assert seq_lens_tensor.shape[0] == len(slot_mapping) - # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -548,6 +451,57 @@ def _prepare_decode( dtype=torch.int, device=self.device, ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): @@ -555,53 +509,75 @@ def _prepare_decode( # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html self.flashinfer_workspace_buffer = torch.empty( 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, dtype=torch.int, device=self.device) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, dtype=torch.int, device=self.device) kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, workspace_buffer=self.flashinfer_workspace_buffer, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, num_qo_heads=self.model_config.get_num_attention_heads( self.parallel_config), num_kv_heads=self.model_config.get_num_kv_heads( self.parallel_config), head_dim=self.model_config.get_head_size(), - page_size=self.block_size, + page_size=16, + seq_start_loc=seq_start_loc, data_type=kv_cache_dtype) else: attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - seq_lens=None, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - max_seq_len=max_seq_len, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return PrepareDecodeMetadata( - input_tokens=input_tokens, - input_positions=input_positions, + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, attn_metadata=attn_metadata, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, lora_requests=lora_requests, - slot_mapping=slot_mapping, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, ) def prepare_input_tensors( @@ -610,85 +586,25 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, + attn_metadata, seq_lens, query_lens, - lora_index_mapping, - lora_prompt_mapping, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(seq_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE - metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -701,46 +617,24 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) + attn_metadata = None sampling_metadata = SamplingMetadata( seq_groups=None, selected_token_indices=selected_token_indices, @@ -748,22 +642,6 @@ def prepare_input_tensors( num_prompts=0, ) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -954,26 +832,22 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - decode_metadata = self.attn_backend.make_metadata( - is_prompt=False, + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=seq_lens[:batch_size], max_query_len=None, - max_seq_len=self.max_seq_len_to_capture, - subquery_start_loc=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) - attn_metadata = AttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - prefill_metadata=None, - decode_metadata=decode_metadata, - ) if self.lora_config: lora_mapping = LoRAMapping( From e9cdd2b1e20beb1c21c55441d0e6a4ed86f4e292 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 14:38:40 +0800 Subject: [PATCH 3/4] [CI/Build] Further decouple HuggingFace implementation from ours during tests (#4166) --- tests/conftest.py | 77 +++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 999ace2c3c699..c1a44a606e1bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,21 @@ import contextlib import gc import os -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import pytest import torch from PIL import Image -from transformers import (AutoModelForCausalLM, AutoProcessor, - LlavaForConditionalGeneration) +from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, + LlavaConfig, LlavaForConditionalGeneration) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.logger import init_logger from vllm.sequence import MultiModalData -from vllm.transformers_utils.tokenizer import get_tokenizer + +logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -129,9 +131,7 @@ def example_long_prompts() -> List[str]: "float": torch.float, } -_VISION_LANGUAGE_MODELS = { - "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, -} +AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration) _EMBEDDING_MODELS = [ "intfloat/e5-mistral-7b-instruct", @@ -143,23 +143,14 @@ class HfRunner: def __init__( self, model_name: str, - tokenizer_name: Optional[str] = None, dtype: str = "half", ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + self.model_name = model_name - if model_name in _VISION_LANGUAGE_MODELS: - self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = AutoProcessor.from_pretrained( - model_name, - torch_dtype=torch_dtype, - ) - elif model_name in _EMBEDDING_MODELS: + + if model_name in _EMBEDDING_MODELS: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer self.model = SentenceTransformer( @@ -172,10 +163,24 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, ).cuda() - self.processor = None - if tokenizer_name is None: - tokenizer_name = model_name - self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for " + "model %s. Using tokenizer instead.", model_name) + self.processor = self.tokenizer def generate( self, @@ -187,19 +192,19 @@ def generate( if images: assert len(prompts) == len(images) for i, prompt in enumerate(prompts): - if self.model_name not in _VISION_LANGUAGE_MODELS: - input_ids = self.tokenizer(prompt, - return_tensors="pt").input_ids - inputs = {"input_ids": input_ids.cuda()} - else: - image = images[i] if images else None - inputs = self.processor(text=prompt, - images=image, - return_tensors="pt") - inputs = { - key: value.cuda() if value is not None else None - for key, value in inputs.items() - } + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + inputs = { + key: value.cuda() if value is not None else None + for key, value in inputs.items() + } + output_ids = self.model.generate( **inputs, use_cache=True, From a5675d348b126e53928e139d1ed5b2c00a0044e8 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 15 May 2024 07:22:09 -0700 Subject: [PATCH 4/4] [Bugfix] Properly set distributed_executor_backend in ParallelConfig (#4816) --- vllm/config.py | 1 + vllm/engine/arg_utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 26edd4567b9ac..2eb5bdd18d812 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -531,6 +531,7 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + placement_group: ray distributed model workers placement group. distributed_executor_backend: Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If either pipeline_parallel_size or tensor_parallel_size is greater than 1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd44c2470182b..dab86b7c9eb35 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -548,14 +548,18 @@ def create_engine_config(self, ) -> EngineConfig: model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( - self.pipeline_parallel_size, self.tensor_parallel_size, - self.worker_use_ray, self.max_parallel_loading_workers, + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, self.disable_custom_all_reduce, TokenizerPoolConfig.create_config( self.tokenizer_pool_size, self.tokenizer_pool_type, self.tokenizer_pool_extra_config, - ), self.ray_workers_use_nsight) + ), + self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend) speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config,