From efbce85f4d375d7851a491a0126a224e25d9f91d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 16 Dec 2024 13:14:57 -0500 Subject: [PATCH 001/200] [misc] Layerwise profile updates (#10242) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- .buildkite/test-pipeline.yaml | 2 +- examples/offline_profile.py | 236 +++++++++++++++--- tools/profiler/print_layerwise_table.py | 9 +- tools/profiler/visualize_layerwise_profile.py | 92 ++++++- vllm/profiler/layerwise_profile.py | 22 +- 5 files changed, 314 insertions(+), 47 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 97aae233db105..44f47fac1c1b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -201,7 +201,7 @@ steps: - python3 offline_inference_classification.py - python3 offline_inference_embedding.py - python3 offline_inference_scoring.py - - python3 offline_profile.py --model facebook/opt-125m + - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min mirror_hardwares: [amd] diff --git a/examples/offline_profile.py b/examples/offline_profile.py index 1d415b82cddb6..46afe8aa2604b 100644 --- a/examples/offline_profile.py +++ b/examples/offline_profile.py @@ -4,9 +4,10 @@ import sys from argparse import RawTextHelpFormatter from dataclasses import asdict, dataclass -from typing import Optional +from typing import Any, Dict, Generator, List, Optional, TypeAlias import torch +import tqdm from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs @@ -15,16 +16,21 @@ BATCH_SIZE_DEFAULT = 1 PROMPT_LEN_DEFAULT = 256 -OUTPUT_LEN_DEFAULT = 2 @dataclass class ProfileContext: engine_args: EngineArgs prompt_len: int - output_len: int batch_size: int - save_chrome_traces_folder: Optional[str] + + # The profiler can run in 2 modes, + # 1. Run profiler for user specified num_steps + num_steps: Optional[int] = None + # 2. Run profiler until all requests complete + complete_num_requests_per_step: Optional[int] = None + + save_chrome_traces_folder: Optional[str] = None def get_dtype(dtype: str): @@ -34,23 +40,155 @@ def get_dtype(dtype: str): return dtype +OutputLen_NumReqs_Map: TypeAlias = Dict[int, int] +def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \ + -> OutputLen_NumReqs_Map: + """ + Given the number of requests, batch_size, and the number of requests + that each engine-step should process, step_requests, determine the + output lengths of the requests such that step_request is honoured. + + Example: + if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1] + then return, + {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning, + 32 requests should have output length 2, + 32 requests should have output length 3, + 32 requests should have output length 4, + 31 requests should have output length 5, + 1 request should have output length 6. + + Args: + batch_size (int): Number of requests submitted for profile. This is + args.batch_size. + step_requests (List[int]): step_requests[i] is the number of requests + that the ith engine step should process. + + Returns: + OutputLen_NumReqs_Map : A dictionary with output-length as keys and the + number of requests required to have that output-length as values. + """ + ol_nr: OutputLen_NumReqs_Map = {} + + # Number of request that are assigned an output-length + num_reqs_assigned: int = 0 + num_steps: int = len(step_requests) + + # sanity check. 
The first step (prefill-step), must process all requests. + assert step_requests[0] == batch_size + + # Begin assignments from the last step. + output_length: int = num_steps + for num_requests_at_step in reversed(step_requests): + if num_reqs_assigned == batch_size: + break + + assert num_reqs_assigned < batch_size + + # Remove the number of requests that have been determined + # to participate in this step and beyond. + num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned + assert num_reqs_unassigned_at_step >= 0 + + if num_reqs_unassigned_at_step > 0: + ol_nr[output_length] = num_reqs_unassigned_at_step + num_reqs_assigned += num_reqs_unassigned_at_step + + output_length -= 1 + + # sanity checks. + assert sum(ol_nr.values()) == batch_size, \ + ("Number of requests in output-length assignment does not match " + f"batch-size.\n batch size {batch_size} - " + f"step requests {step_requests} - assignments {ol_nr}") + + # Check that the output-length is in [1, num-steps]. Output length must be + # at least 1 as all requests must participate in the prefill-step. + assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \ + ("Output lengths of requests should be in range " + f"[1, num-engine-steps].\n batch size {batch_size} - " + f"step requests {step_requests} - assignments {ol_nr}") + + return ol_nr + + +def determine_requests_per_step(context: ProfileContext) -> List[int]: + """ + Determine number of requests each engine step should process. + If context.num_steps is set, then all engine steps process the + same number of requests and the output list is of length + context.num_steps. + + If context.complete_num_requests_per_step is set, then each decode step + processes fewer and fewer requests until there are no requests to process. + In this case, the output list is as big as the number of steps + required to process all requests. + + Args: + context: ProfileContext object. + + Returns: + List[int]: Number of requests to process for all engine-steps. + output[i], contains the number of requests that the ith step + should process. + """ + if context.num_steps: + # All requests must run until num_engine_steps. This implies + # that their output lengths must be equal to num_engine_steps. + return [context.batch_size] * context.num_steps + + assert context.complete_num_requests_per_step and \ + context.complete_num_requests_per_step > 0, \ + (f"Expected a positive complete_num_requests_per_step argument." + f"Instead got {context.complete_num_requests_per_step}") + + # We start dropping after the first decode step. + step_requests = [ + context.batch_size, # prefill + context.batch_size, # decode + ] + + num_running_requests = context.batch_size + num_running_requests -= context.complete_num_requests_per_step + while num_running_requests > 0: + step_requests.append(num_running_requests) + num_running_requests -= context.complete_num_requests_per_step + + if step_requests[-1] != 1: + # have 1 request running at the last step. 
This is often + # useful + step_requests.append(1) + + return step_requests + + def run_profile(context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]): print("Run profile with:") for key, value in asdict(context).items(): print(f" {key} = {value}") + requests_per_step: List[int] = determine_requests_per_step(context) + + ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( + context.batch_size, requests_per_step) + + num_steps_to_profile: int = len(requests_per_step) + max_output_len: int = max(ol_nr.keys()) + assert max_output_len >= 1 + # Create sampling params - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=args.output_len, - ignore_eos=True) + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + # max_tokens is set on a per-request basis. + max_tokens=None, + ignore_eos=True) # Create LLM llm = LLM(**asdict(context.engine_args)) batch_size = context.batch_size prompt_len = context.prompt_len - output_len = context.output_len scheduler_config = llm.llm_engine.scheduler_config max_model_len = llm.llm_engine.model_config.max_model_len @@ -65,7 +203,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], f"choose a smaller batch size or prompt length, or increase " f"--max-num-batched-tokens") sys.exit(-1) - if batch_size >= max_num_seqs: + if batch_size > max_num_seqs: print( f"ERROR: chosen batch_size ({batch_size}) is larger than " f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " @@ -73,16 +211,26 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], sys.exit(-1) print("llm.llm_engine.model_config.max_model_len: ", llm.llm_engine.model_config.max_model_len) - if prompt_len + output_len > llm.llm_engine.model_config.max_model_len: - print( - f"ERROR: chosen prompt_len + output_len ({prompt_len} + " - f"{output_len} = {prompt_len + output_len}) is larger than the " - f"model's max_model_len ({max_model_len}), please choose a smaller " - f"prompt_len or output_len, or increase --max-model-len") + if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: + print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " + f"{max_output_len} = {prompt_len + max_output_len}) is larger " + f"than the model's max_model_len ({max_model_len}), please " + f"choose a smaller prompt_len or max_output_len, or increase " + f"--max-model-len") sys.exit(-1) def add_requests(): + + def get_output_len_generator() -> Generator[int, Any, Any]: + for output_len, num_reqs in ol_nr.items(): + for _ in range(num_reqs): + yield output_len + + output_len_generator = get_output_len_generator() for i in range(batch_size): + sampling_params.max_tokens = next(output_len_generator) + assert isinstance(sampling_params.max_tokens, int) + prompt_token_ids = torch.randint( llm.llm_engine.model_config.get_vocab_size(), size=(prompt_len, )).tolist() @@ -110,8 +258,11 @@ def abort_requests(): llm.llm_engine.step() # First step is prefill decode_profs = [] - for x in range(args.output_len - 1): - with layerwise_profile() as decode_prof: + for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): + num_running_seqs = llm.llm_engine.scheduler[ + 0].get_num_unfinished_seq_groups() + with layerwise_profile( + num_running_seqs=num_running_seqs) as decode_prof: llm.llm_engine.step() decode_profs.append(decode_prof) @@ -154,7 +305,8 @@ def abort_requests(): decode_results_list[0].print_summary_table() if csv_output: - csv_filename_base = csv_output.rstrip(".csv") + csv_filename_base = 
csv_output[:-4] \ + if csv_output.endswith('.csv') else csv_output prefill_results.export_model_stats_table_csv( csv_filename_base + "_prefill_model_table.csv") prefill_results.export_summary_stats_table_csv( @@ -187,10 +339,10 @@ def abort_requests(): for idx, dr in enumerate(decode_results_list): json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - for idx, dr in enumerate(decode_results_list[1:]): - json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - - with open(json_output.rstrip(".json") + ".json", "w+") as f: + # Add .json to json_output filename if it doesn't exist already. + json_output_file = json_output if json_output.endswith( + '.json') else json_output + '.json' + with open(json_output_file, "w+") as f: json.dump(json_dict, f, indent=2) pass @@ -214,7 +366,7 @@ def abort_requests(): python examples/offline_profile.py \\ --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ - --enforce-eager + --enforce-eager run_num_steps -n 2 ``` then you can use various tools to analyze the json output @@ -261,17 +413,41 @@ def abort_requests(): default=BATCH_SIZE_DEFAULT, help=f"Number of requests to run as a single batch, " f"default={BATCH_SIZE_DEFAULT}") - parser.add_argument( - "--output-len", + + subparsers = parser.add_subparsers(dest="cmd") + + run_num_steps_parser = subparsers.add_parser( + "run_num_steps", + help="This variation profiles n engine.step() invocations.") + run_num_steps_parser.add_argument( + '-n', + '--num-steps', type=int, - default=OUTPUT_LEN_DEFAULT, - help="Number of llm steps to run (includes prefill and decode) " - "- default={OUTPUT_LEN_DEFAULT}") + help="Number of engine steps to profile.\n" + "Setting it to 1, profiles only the prefill step.\n" + "Setting it to 2, profiles the prefill and first decode step\n" + "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" + "and so on ...") + + run_to_completion_parser = subparsers.add_parser( + "run_to_completion", + help="This variation profiles all the engine.step() invocations" + "until the engine exhausts all submitted requests.") + run_to_completion_parser.add_argument( + '-n', + '--complete-num-requests-per-step', + type=int, + help= + "Complete complete_num_requests_per_step requests every decode step." + "For e.g., with batch_size 128 and complete_num_requests_per_step 32," + "the profiler is run for 6 engine steps, with the steps processing, " + "128, 128, 96, 64, 32, 1 requests respectively.\n" + "Note that we tack-on a one-request step at the end as it is often " + "useful.") EngineArgs.add_cli_args(parser) args = parser.parse_args() - context = ProfileContext( engine_args=EngineArgs.from_cli_args(args), **{ diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py index 081076ad7dbdc..394ca8663e189 100644 --- a/tools/profiler/print_layerwise_table.py +++ b/tools/profiler/print_layerwise_table.py @@ -34,9 +34,10 @@ def get_entries(node, curr_depth=0): "examples/offline_profile.py") parser.add_argument("--phase", type=str, - choices=["prefill", "decode_1"], required=True, - help="The phase to print the table for.") + help="The phase to print the table for. 
This is either" + "prefill or decode_n, where n is the decode step " + "number") parser.add_argument("--table", type=str, choices=["summary", "model"], @@ -49,6 +50,10 @@ def get_entries(node, curr_depth=0): with open(args.json_trace) as f: profile_data = json.load(f) + assert args.phase in profile_data, \ + (f"Cannot find phase {args.phase} in profile data. Choose one among" + f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa + if args.table == "summary": entries_and_depths = flatten_entries( SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index adc44474aa4c1..da7a28da15c19 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -151,16 +151,31 @@ def is_quant(op_name: str): "scaled_int8_quant" in op_name: return True + # LoRA ops + def is_sgmv_shrink(op_name: str): + return "sgmv_shrink" in op_name + + def is_sgmv_expand(op_name: str): + return "sgmv_expand" in op_name + + def is_bgmv_shrink(op_name: str): + return "bgmv_shrink" in op_name + + def is_bgmv_expand(op_name: str): + return "bgmv_expand" in op_name + + def is_cutlass_gemm_op(op_name: str): + return "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name + def is_gemm_op(op_name: str): if is_quant(op_name): return False - if "xmma_gemm" in op_name or \ + return is_cutlass_gemm_op(op_name) or \ + "xmma_gemm" in op_name or \ "gemv2T_kernel" in op_name or \ "splitKreduce" in op_name or \ - "void cutlass::Kernel" in op_name or \ - "void cutlass::device_kernel" in op_name or \ - "s16816gemm" in op_name: - return True + "s16816gemm" in op_name def is_elementwise_op(op_name: str): return "elementwise_kernel" in op_name @@ -211,6 +226,18 @@ def is_reduce_kernel(op_name: str): quant_ops = list(filter(lambda x: is_quant(x), ops)) ops = list(filter(lambda x: x not in quant_ops, ops)) + sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops)) + ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops)) + sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops)) + ops = list(filter(lambda x: x not in sgmv_expand_ops, ops)) + bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops)) + ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops)) + bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops)) + ops = list(filter(lambda x: x not in bgmv_expand_ops, ops)) + + cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops)) + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) ops = list(filter(lambda x: x not in gemm_ops, ops)) @@ -257,6 +284,24 @@ def is_reduce_kernel(op_name: str): trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) if len(quant_ops): trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + + if len(sgmv_shrink_ops): + trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum", + axis=1) + if len(sgmv_expand_ops): + trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum", + axis=1) + if len(bgmv_shrink_ops): + trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum", + axis=1) + if len(bgmv_expand_ops): + trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum", + axis=1) + + if len(cutlass_gemm_ops): + trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum", + axis=1) + if len(gemm_ops): trace_df['gemm_ops'] = 
trace_df[gemm_ops].agg("sum", axis=1) if len(rms_norm_ops): @@ -296,7 +341,9 @@ def is_reduce_kernel(op_name: str): trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", axis=1) - trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops + + sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops + + cutlass_gemm_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops + @@ -315,7 +362,14 @@ def plot_trace_df(traces_df: pd.DataFrame, plot_title: str, output: Optional[Path] = None): + def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: + phase_df = traces_df.query(f'phase == "{phase}"') + descs = phase_df['phase_desc'].to_list() + assert all([desc == descs[0] for desc in descs]) + return descs[0] + phases = traces_df['phase'].unique() + phase_descs = [get_phase_description(traces_df, p) for p in phases] traces_df = traces_df.pivot_table(index="phase", columns="name", values=plot_metric, @@ -324,7 +378,8 @@ def plot_trace_df(traces_df: pd.DataFrame, traces_df = group_trace_by_operations(traces_df) # Make the figure - fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + fig_size_x = max(5, len(phases)) + fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True) # Draw the stacked bars ops = list(traces_df) @@ -332,7 +387,7 @@ def plot_trace_df(traces_df: pd.DataFrame, for op in ops: values = [traces_df[op][phase] for phase in phases] values = list(map(lambda x: 0.0 if math.isnan(x) else x, values)) - ax.bar(phases, values, label=op, bottom=bottom) + ax.bar(phase_descs, values, label=op, bottom=bottom) bottom = [bottom[j] + values[j] for j in range(len(phases))] # Write the values as text on the bars @@ -390,6 +445,14 @@ def keep_only_top_entries(df: pd.DataFrame, ["name"]] = "others" return df + def get_phase_description(key: str) -> str: + num_running_seqs = profile_json[key]['metadata'][ + 'num_running_seqs'] + if num_running_seqs is not None: + return f"{key}-seqs-{num_running_seqs}" + else: + return key + # Get data for each key traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) @@ -413,6 +476,7 @@ def keep_only_top_entries(df: pd.DataFrame, # Fill in information about the step-keys for trace_df, step_key in zip(trace_dfs, step_keys): trace_df['phase'] = step_key + trace_df['phase_desc'] = get_phase_description(step_key) # Combine all data frames so they can be put in a single plot traces_df = pd.concat(trace_dfs) @@ -426,12 +490,16 @@ def keep_only_top_entries(df: pd.DataFrame, def make_plot_title_suffix(profile_json: dict) -> str: context = profile_json["context"] sparsity = context.get('sparsity', None) - return (f"{context['model']}\n" + run_type = \ + f'Run {context["num_steps"]} steps' if context['num_steps'] else \ + (f'Complete {context["complete_num_requests_per_step"]} per ' + f'step; Run till completion') + return (f"{context['engine_args']['model']}\n" f"Batch={context['batch_size']}, " f"PromptLen={context['prompt_len']}, " - f"OutputLen={context['output_len']}," - f"NumGpus={context['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}") + f"NumGpus={context['engine_args']['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}\n" + f"Run Type: {run_type}") profile_json = None with open(json_trace) as f: diff --git a/vllm/profiler/layerwise_profile.py 
b/vllm/profiler/layerwise_profile.py index 9d9f427e807f6..33babfebdca1e 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -72,6 +72,9 @@ class LayerwiseProfileResults(profile): _model_stats_tree: List[_StatsTreeNode] = field(init=False) _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + # profile metadata + num_running_seqs: Optional[int] = None + def __post_init__(self): self._build_correlation_map() self._build_module_tree() @@ -127,6 +130,9 @@ def export_summary_stats_table_csv(self, filename: str): def convert_stats_to_dict(self) -> str: return { + "metadata": { + "num_running_seqs": self.num_running_seqs + }, "summary_stats": self._convert_stats_tree_to_dict(self._summary_stats_tree), "model_stats": @@ -338,7 +344,15 @@ def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): class layerwise_profile(profile): - def __init__(self): + def __init__(self, num_running_seqs: Optional[int] = None): + """ + layerwise profile constructor. + + Args: + num_running_seqs (Optional[int], optional): When given, + num_running_seqs will be passed to LayerProfileResults for metadata + update. Defaults to None. + """ super().__init__( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, @@ -346,9 +360,13 @@ def __init__(self): with_modules=True, experimental_config=_ExperimentalConfig(verbose=True)) + self.num_running_seqs = num_running_seqs + def __enter__(self): return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) - self.results = LayerwiseProfileResults(self.profiler.kineto_results) + self.results = LayerwiseProfileResults( + self.profiler.kineto_results, + num_running_seqs=self.num_running_seqs) From 551603feffd9b4ba98ccdd34e02e403e04db88c1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Dec 2024 13:32:25 -0800 Subject: [PATCH 002/200] [core] overhaul memory profiling and fix backward compatibility (#10511) Signed-off-by: youkaichao --- tests/entrypoints/llm/test_gpu_utilization.py | 25 ++++ tests/entrypoints/llm/test_lazy_outlines.py | 2 +- tests/test_utils.py | 44 +++++- tests/worker/test_profile.py | 18 +-- vllm/engine/arg_utils.py | 11 +- vllm/utils.py | 125 +++++++++++++++++- vllm/worker/multi_step_model_runner.py | 3 +- vllm/worker/worker.py | 68 ++++------ 8 files changed, 236 insertions(+), 60 deletions(-) create mode 100644 tests/entrypoints/llm/test_gpu_utilization.py diff --git a/tests/entrypoints/llm/test_gpu_utilization.py b/tests/entrypoints/llm/test_gpu_utilization.py new file mode 100644 index 0000000000000..c2dab300ecefb --- /dev/null +++ b/tests/entrypoints/llm/test_gpu_utilization.py @@ -0,0 +1,25 @@ +from vllm import LLM, SamplingParams + + +def test_gpu_memory_utilization(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # makes sure gpu_memory_utilization is per-instance limit, + # not a global limit + llms = [ + LLM(model="facebook/opt-125m", + gpu_memory_utilization=0.3, + enforce_eager=True) for i in range(3) + ] + for llm in llms: + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index 
2c53676c5f5dd..bf609b38a94f5 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -36,7 +36,7 @@ def run_lmfe(sample_regex): llm = LLM(model="facebook/opt-125m", enforce_eager=True, guided_decoding_backend="lm-format-enforcer", - gpu_memory_utilization=0.6) + gpu_memory_utilization=0.3) sampling_params = SamplingParams(temperature=0.8, top_p=0.95) outputs = llm.generate( prompts=[ diff --git a/tests/test_utils.py b/tests/test_utils.py index a731b11eae81c..0bc9e5bc32a46 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,11 +5,13 @@ from typing import AsyncIterator, Tuple import pytest +import torch from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, - get_open_port, merge_async_iterators, supports_kw) + get_open_port, memory_profiling, merge_async_iterators, + supports_kw) -from .utils import error_on_warning +from .utils import error_on_warning, fork_new_process_for_each_test @pytest.mark.asyncio @@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only, requires_kw_only=requires_kw_only, allow_var_kwargs=allow_var_kwargs ) == is_supported + + +@fork_new_process_for_each_test +def test_memory_profiling(): + # Fake out some model loading + inference memory usage to test profiling + # Memory used by other processes will show up as cuda usage outside of torch + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib = CudaRTLibrary() + # 512 MiB allocation outside of this instance + handle1 = lib.cudaMalloc(512 * 1024 * 1024) + + baseline_memory_in_bytes = \ + torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0] + + # load weights + + weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32) + + weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB + + with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes, + weights_memory_in_bytes=weights_memory_in_bytes) as result: + # make a memory spike, 1 GiB + spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) + del spike + + # Add some extra non-torch memory 256 MiB (simulate NCCL) + handle2 = lib.cudaMalloc(256 * 1024 * 1024) + + # Check that the memory usage is within 5% of the expected values + non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa + torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa + assert abs(non_torch_ratio - 1) <= 0.05 + assert abs(torch_peak_ratio - 1) <= 0.05 + del weights + lib.cudaFree(handle1) + lib.cudaFree(handle2) diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py index 194ea2aa506f4..79233c75714de 100644 --- a/tests/worker/test_profile.py +++ b/tests/worker/test_profile.py @@ -31,10 +31,6 @@ def test_gpu_memory_profiling(): is_driver_worker=True, ) - # Load the model so we can profile it - worker.init_device() - worker.load_model() - # Set 10GiB as the total gpu ram to be device-agnostic def mock_mem_info(): current_usage = torch.cuda.memory_stats( @@ -46,20 +42,24 @@ def mock_mem_info(): from unittest.mock import patch with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info): + # Load the model so we can profile it + worker.init_device() + worker.load_model() gpu_blocks, _ = worker.determine_num_available_blocks() - # Peak vram usage by torch should be 0.7077 GiB + # Peak vram usage by torch should be 0.47 GiB + # Model weights take 0.25 GiB # No memory should be allocated outside of torch # 9.0 GiB should be the 
utilization target - # 8.2923 GiB should be available for the KV cache + # 8.28 GiB should be available for the KV cache block_size = CacheEngine.get_cache_block_size( engine_config.cache_config, engine_config.model_config, engine_config.parallel_config) - expected_blocks = (8.2923 * 1024**3) // block_size + expected_blocks = (8.28 * 1024**3) // block_size # Check within a small tolerance for portability # Hardware, kernel, or dependency changes could all affect memory # utilization. - # A 10 block tolerance here should be about 6MB of wiggle room. - assert abs(gpu_blocks - expected_blocks) < 10 + # A 100 block tolerance here should be about 60MB of wiggle room. + assert abs(gpu_blocks - expected_blocks) < 100 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0aa367a173b6c..06b8542779dc0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -487,11 +487,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='The fraction of GPU memory to be used for the model ' 'executor, which can range from 0 to 1. For example, a value of ' '0.5 would imply 50%% GPU memory utilization. If unspecified, ' - 'will use the default value of 0.9. This is a global gpu memory ' - 'utilization limit, for example if 50%% of the gpu memory is ' - 'already used before vLLM starts and --gpu-memory-utilization is ' - 'set to 0.9, then only 40%% of the gpu memory will be allocated ' - 'to the model executor.') + 'will use the default value of 0.9. This is a per-instance ' + 'limit, and only applies to the current vLLM instance.' + 'It does not matter if you have another vLLM instance running ' + 'on the same GPU. For example, if you have two vLLM instances ' + 'running on the same GPU, you can set the GPU memory utilization ' + 'to 0.5 for each instance.') parser.add_argument( '--num-gpu-blocks-override', type=int, diff --git a/vllm/utils.py b/vllm/utils.py index 45e682ac15782..73d2ae25f15ca 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -23,10 +23,12 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generic, Hashable, List, Literal, Optional, - OrderedDict, Set, Tuple, Type, TypeVar, Union, overload) + Dict, Generator, Generic, Hashable, List, Literal, + Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union, + overload) from uuid import uuid4 import numpy as np @@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int): # Finally kill the parent with contextlib.suppress(ProcessLookupError): os.kill(pid, signal.SIGKILL) + + +@dataclass +class MemorySnapshot: + """Memory snapshot.""" + torch_peak_in_bytes: int = 0 + torch_memory_in_bytes: int = 0 + timestamp: float = 0.0 + + def measure(self): + self.torch_peak_in_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.peak"] + self.torch_memory_in_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + self.timestamp = time.time() + + def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": + """support a - b""" + return MemorySnapshot( + torch_peak_in_bytes=self.torch_peak_in_bytes - + other.torch_peak_in_bytes, + torch_memory_in_bytes=self.torch_memory_in_bytes - + other.torch_memory_in_bytes, + timestamp=self.timestamp - other.timestamp) + + +@dataclass +class MemoryProfilingResult: + 
"""Memory profiling result. + """ # noqa + baseline_memory_in_bytes: int = 0 + non_kv_cache_memory_in_bytes: int = 0 + torch_peak_increase_in_bytes: int = 0 + non_torch_increase_in_bytes: int = 0 + weights_memory_in_bytes: float = 0 + before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + profile_time: float = 0.0 + + +@contextlib.contextmanager +def memory_profiling( + baseline_memory_in_bytes: int, weights_memory_in_bytes: int +) -> Generator[MemoryProfilingResult, None, None]: + """Memory profiling context manager. + baseline_memory_in_bytes: memory used by all the components other than + the current vLLM instance. It contains: memory used by other processes, memory + used by another vLLM instance in the same process, etc. It is usually measured + before the current vLLM instance initialize the device. And we assume it is + constant during the profiling of the current vLLM instance. + weights_memory_in_bytes: memory used by PyTorch when loading the model weights. + Note that, before loading the model weights, we also initialize the device + and distributed environment, which may consume some memory. This part is not + included in the weights_memory_in_bytes because PyTorch does not control it. + + The memory in one GPU can be classified into 3 categories: + 1. memory used by anything other than the current vLLM instance. + 2. memory used by torch in the current vLLM instance. + 3. memory used in the current vLLM instance, but not by torch. + + A quantitive example: + + Before creating the current vLLM instance: + category 1: 1 GiB + category 2: 0 GiB + category 3: 0 GiB + + After creating the current vLLM instance and loading the model, + (i.e. before profiling): + category 1: 1 GiB + category 2: 2 GiB (model weights take 2 GiB) + category 3: 0.5 GiB (memory used by NCCL) + + During profiling (peak): + category 1: 1 GiB + category 2: 4 GiB (peak activation tensors take 2 GiB) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + After profiling: + category 1: 1 GiB + category 2: 3 GiB (after garbage-collecting activation tensors) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + In this case, non-kv cache takes 5 GiB in total, including: + a. 2 GiB used by the model weights (category 2) + b. 2 GiB reserved for the peak activation tensors (category 2) + c. 1 GiB used by non-torch components (category 3) + + The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`. + + The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.). + + (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`), + subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`. 
+ """ # noqa + torch.cuda.reset_peak_memory_stats() + + result = MemoryProfilingResult() + + result.baseline_memory_in_bytes = baseline_memory_in_bytes + # the part of memory used for holding the model weights + result.weights_memory_in_bytes = weights_memory_in_bytes + + result.before_profile.measure() + + yield result + + gc.collect() + torch.cuda.empty_cache() + + result.after_profile.measure() + + diff = result.after_profile - result.before_profile + result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes + current_cuda_memory_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa + result.profile_time = diff.timestamp + result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index e08a61e31fe42..18b03bf1bfb56 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -645,7 +645,8 @@ def _advance_step(self, model_input: StatefulModelInput, return model_input def load_model(self) -> None: - return self._base_model_runner.load_model() + self._base_model_runner.load_model() + self.model_memory_usage = self._base_model_runner.model_memory_usage def save_sharded_state( self, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a368bb9ee9a5b..f51b51d433d3d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,6 @@ """A GPU worker class.""" import gc import os -import time from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -22,6 +21,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.utils import GiB_bytes, memory_profiling from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -192,33 +192,22 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - start_time = time.time() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() - torch.cuda.synchronize() + with memory_profiling(baseline_memory_in_bytes=total_gpu_memory - + self.init_gpu_memory, + weights_memory_in_bytes=self.model_runner. + model_memory_usage) as result: + self.model_runner.profile_run() + torch.cuda.synchronize() self._assert_memory_footprint_increased_during_profiling() - # Get the peak memory allocation recorded by torch - peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - - # Check for any memory left around that may have been allocated on the - # gpu outside of `torch`. 
NCCL operations, for example, can use a few - # GB during a forward pass - torch.cuda.empty_cache() - torch_allocated_bytes = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = torch.cuda.mem_get_info( - )[1] - torch.cuda.mem_get_info()[0] - non_torch_allocations = total_allocated_bytes - torch_allocated_bytes - if non_torch_allocations > 0: - peak_memory += non_torch_allocations - - available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) + memory_for_current_instance = total_gpu_memory * \ + self.cache_config.gpu_memory_utilization + available_kv_cache_memory = (memory_for_current_instance - + result.non_kv_cache_memory_in_bytes) # Calculate the number of blocks that can be allocated with the # profiled peak memory. @@ -233,24 +222,23 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - end_time = time.time() - logger.info( - "Memory profiling results: " - "duration=%.2f seconds, " - "total_gpu_memory=%.2fGiB, " - "initial_memory_usage=%.2fGiB, " - "peak_torch_memory=%.2fGiB, " - "memory_usage_post_profile=%.2fGiB, " - "non_torch_memory=%.2fGiB, " - "kv_cache_size=%.2fGiB, " - "gpu_memory_utilization=%.2f.", end_time - start_time, - total_gpu_memory / (1024**3), - (total_gpu_memory - free_memory_pre_profile) / (1024**3), - (peak_memory - non_torch_allocations) / (1024**3), - total_allocated_bytes / (1024**3), - non_torch_allocations / (1024**3), - available_kv_cache_memory / (1024**3), - self.cache_config.gpu_memory_utilization) + msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" + "the current vLLM instance can use " + "total_gpu_memory " + f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" + " x gpu_memory_utilization " + f"({self.cache_config.gpu_memory_utilization:.2f})" + f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" + "model weights take " + f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;" + " non_torch_memory takes " + f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;" + " PyTorch activation peak memory takes " + f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;" + " the rest of the memory reserved for KV Cache is " + f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") + + logger.info(msg) # Final cleanup if self.model_runner.lora_manager: From 35ffa682b1cd3f47eb6cda586a16dab5c0401477 Mon Sep 17 00:00:00 2001 From: bk-TurbaAI Date: Mon, 16 Dec 2024 23:20:39 +0100 Subject: [PATCH 003/200] [Docs] hint to enable use of GPU performance counters in profiling tools for multi-node distributed serving (#11235) Co-authored-by: Michael Goin --- docs/source/serving/distributed_serving.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst index 4d57206e53a05..b24ba53e59694 100644 --- a/docs/source/serving/distributed_serving.rst +++ b/docs/source/serving/distributed_serving.rst @@ -54,7 +54,7 @@ Multi-Node Inference and Serving If a single node does not have enough GPUs to hold the model, you can run the model using multiple nodes. It is important to make sure the execution environment is the same on all nodes, including the model path, the Python environment. The recommended way is to use docker images to ensure the same environment, and hide the heterogeneity of the host machines via mapping them into the same docker configuration. 
-The first step, is to start containers and organize them into a cluster. We have provided a helper `script `_ to start the cluster. +The first step, is to start containers and organize them into a cluster. We have provided a helper `script `_ to start the cluster. Please note, this script launches docker without administrative privileges that would be required to access GPU performance counters when running profiling and tracing tools. For that purpose, the script can have ``CAP_SYS_ADMIN`` to the docker container by using the ``--cap-add`` option in the docker run command. Pick a node as the head node, and run the following command: From c301616ed23fef433db1a49df332b9d61d3178ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Dec 2024 15:53:18 -0800 Subject: [PATCH 004/200] [ci][tests] add gh200 tests (#11244) Signed-off-by: youkaichao --- .buildkite/run-gh200-test.sh | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 .buildkite/run-gh200-test.sh diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh new file mode 100644 index 0000000000000..d25510c47fe6b --- /dev/null +++ b/.buildkite/run-gh200-test.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# This script build the GH200 docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +DOCKER_BUILDKIT=1 docker build . \ + --target test \ + -platform "linux/arm64" \ + -t gh200-test \ + --build-arg max_jobs=66 \ + --build-arg nvcc_threads=2 \ + --build-arg torch_cuda_arch_list="9.0+PTX" \ + --build-arg vllm_fa_cmake_gpu_arches="90-real" + +# Setup cleanup +remove_docker_container() { docker rm -f gh200-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and test offline inference +docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' + python3 examples/offline_inference.py +' From 88a412ed3d964de3443c42a6a35108115ee0ad25 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Dec 2024 16:15:22 -0800 Subject: [PATCH 005/200] [torch.compile] fast inductor (#11108) Signed-off-by: youkaichao Co-authored-by: Tyler Michael Smith --- vllm/compilation/backends.py | 213 +++++++++++++++++- vllm/config.py | 415 ++++++++++++++++++++++++++++++++++- vllm/envs.py | 3 + 3 files changed, 624 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4a5dc337d01b8..0c7bbfe599b02 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,6 +1,10 @@ +import ast import copy import dataclasses +import os +import pprint import time +from collections import defaultdict from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -21,6 +25,122 @@ logger = init_logger(__name__) +class InductorHashCache: + """ + Disk format: a Python list of tuples, each tuple is + (runtime_shape, graph_index, hash_str) + We use list of tuple for readability. + + In-memory format: a defaultdict of dict, where the key is + runtime_shape, and the value is a dict of graph_index to hash_str. + + The data is essentially `Dict[Optional[int], Dict[int, str]]`, + we don't use json here because json doesn't support int as key. + + TODO: better off-the-shelf solution to serialize the data? 
+ """ + + def __init__(self, cache_dir: str, disabled: bool = False): + self.cache: defaultdict = defaultdict(dict) + self.disabled = disabled + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, + "inductor_hash_cache.py") + if disabled: + return + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + if os.path.exists(self.cache_file_path): + with open(self.cache_file_path) as f: + self.deserialize(f.read()) + + def deserialize(self, data: str): + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. + list_data = ast.literal_eval(data) + for runtime_shape, graph_index, hash_str in list_data: + self.cache[runtime_shape][graph_index] = hash_str + + def serialize(self) -> str: + data = [] + for runtime_shape, graph_index_to_hash_str in self.cache.items(): + for graph_index, hash_str in graph_index_to_hash_str.items(): + data.append((runtime_shape, graph_index, hash_str)) + printer = pprint.PrettyPrinter(indent=4) + return printer.pformat(data) + + def save_to_file(self): + if self.disabled: + return + with open(self.cache_file_path, "w") as f: + f.write(self.serialize()) + + def __contains__(self, key: Tuple[Optional[int], int]) -> bool: + if self.disabled: + return False + runtime_shape, graph_index = key + return runtime_shape in self.cache and graph_index in self.cache[ + runtime_shape] + + def __getitem__(self, key: Tuple[Optional[int], int]) -> str: + if self.disabled: + raise KeyError("cannot read from disabled cache") + runtime_shape, graph_index = key + return self.cache[runtime_shape][graph_index] + + def __setitem__(self, key: Tuple[Optional[int], int], value: str): + # setitem for disabled cache is fine, because we + # don't actually write to the disk + runtime_shape, graph_index = key + self.cache[runtime_shape][graph_index] = value + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. 
+ """ + + def __init__(self) -> None: + self.guards: List[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + def wrap_inductor(graph, example_inputs, additional_inductor_config, @@ -55,9 +175,93 @@ def wrap_inductor(graph, # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 graph = copy.deepcopy(graph) - compiled_graph = compile_fx(graph, - example_inputs, - config_patches=current_config) + + cache_data = compilation_config.inductor_hash_cache + if (runtime_shape, graph_index) in cache_data: + # we compiled this graph before + # so we can directly lookup the compiled graph via hash + hash_str = cache_data[(runtime_shape, graph_index)] + if graph_index == 0: + # adds some info logging for the first graph + logger.info( + "Directly lookup the graph for shape %s from the cache", + str(runtime_shape)) # noqa + logger.debug( + "directly lookup the %s-th graph for shape %s via hash %s", + graph_index, str(runtime_shape), hash_str) + from torch._inductor.codecache import FxGraphCache + with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv()): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the graph we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + else: + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. + from torch._inductor.codecache import compiled_fx_graph_hash + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + # store the hash in the cache + nonlocal cache_data + cache_data[(runtime_shape, graph_index)] = out[0] + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug("store the %s-th graph for shape %s via hash %s", + graph_index, str(runtime_shape), out[0]) + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. 
+ # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env(): + return AlwaysHitShapeEnv() + + with patch(# for hijacking the hash of the compiled graph + "torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash), \ + patch(# for providing a dummy shape environment + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env), \ + patch(# for forcing the graph to be cached + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache): + compiled_graph = compile_fx(graph, + example_inputs, + config_patches=current_config) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: @@ -457,6 +661,9 @@ def __call__(self, *args) -> Any: # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: + + # save the hash of the inductor graph for the next run + self.compilation_config.inductor_hash_cache.save_to_file() end_monitoring_torch_compile(self.vllm_config) if not entry.use_cudagraph: diff --git a/vllm/config.py b/vllm/config.py index fce8011be4015..9cfd08024ea7b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,6 +3,7 @@ import enum import hashlib import json +import os import warnings from contextlib import contextmanager from dataclasses import dataclass, field, replace @@ -162,6 +163,30 @@ class ModelConfig: which allows no processors. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: List[Any] = [] + factors.append(self.model) + factors.append(self.dtype) + factors.append(self.quantization) + factors.append(self.quantization_param_path) + factors.append(self.revision) + factors.append(self.code_revision) + factors.append(self.trust_remote_code) + factors.append(self.rope_scaling) + factors.append(self.rope_theta) + return hashlib.sha256(str(factors).encode()).hexdigest() + def __init__(self, model: str, task: Union[TaskOption, Literal["draft"]], @@ -203,6 +228,8 @@ def __init__(self, self.seed = seed self.revision = revision self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta if hf_overrides is None: hf_overrides = {} @@ -832,6 +859,24 @@ class CacheConfig: cpu_offload_gb: Size of the CPU offload buffer in GiB. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: List[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. 
+ hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __init__( self, block_size: int, @@ -928,6 +973,24 @@ class TokenizerPoolConfig: pool_type: Union[str, Type["BaseTokenizerGroup"]] extra_config: dict + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): if self.pool_type not in ("ray", ) and not isinstance( self.pool_type, type): @@ -1010,6 +1073,24 @@ class LoadConfig: default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} if isinstance(model_loader_extra_config, str): @@ -1073,6 +1154,19 @@ class ParallelConfig: rank: int = 0 + def compute_hash(self): + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: List[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) + return hashlib.sha256(str(factors).encode()).hexdigest() + def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size @@ -1209,6 +1303,24 @@ class SchedulerConfig: chunked_prefill_enabled: bool = field(init=False) + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self) -> None: if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: @@ -1286,6 +1398,25 @@ class DeviceConfig: device: Optional[torch.device] device_type: str + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection @@ -1313,6 +1444,24 @@ class SpeculativeConfig: decoding with top-1 proposals. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # spec decode does not use `torch.compile` yet. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + @staticmethod def maybe_create_spec_config( target_model_config: ModelConfig, @@ -1753,6 +1902,24 @@ class LoRAConfig: long_lora_scaling_factors: Optional[Tuple[float]] = None bias_enabled: bool = False + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # LoRA is not compatible with `torch.compile` . + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast # majority of applications. @@ -1802,6 +1969,24 @@ class PromptAdapterConfig: max_cpu_prompt_adapters: Optional[int] = None prompt_adapter_dtype: Optional[torch.dtype] = None + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): if self.max_prompt_adapters < 1: @@ -1830,6 +2015,24 @@ class MultiModalConfig: for each :class:`~vllm.multimodal.MultiModalPlugin`. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + # TODO: Add configs to init vision tower or not. @@ -1869,6 +2072,24 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + @staticmethod def from_json(json_str: str) -> "PoolerConfig": return PoolerConfig(**json.loads(json_str)) @@ -2103,6 +2324,24 @@ class DecodingConfig: # 'outlines' / 'lm-format-enforcer' / 'xgrammar' guided_decoding_backend: str = 'xgrammar' + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] backend = self.guided_decoding_backend @@ -2124,6 +2363,24 @@ class ObservabilityConfig: # If set, collects the model execute time for the request. collect_model_execute_time: bool = False + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
+ factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + def __post_init__(self): if not is_otel_available() and self.otlp_traces_endpoint is not None: raise ValueError( @@ -2165,6 +2422,24 @@ class KVTransferConfig(BaseModel): # The KV connector port, used to build distributed connection kv_port: int = 14579 + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: List[Any] = [] + hash_str = hashlib.md5(str(factors).encode()).hexdigest() + return hash_str + @classmethod def from_cli(cls, cli_value: str) -> "KVTransferConfig": """Parse the CLI value for the kv cache transfer config.""" @@ -2234,6 +2509,9 @@ class CompilationConfig(BaseModel): - 2: dynamo once. - 3: piecewise compilation. - debug_dump_path: the path to dump the debug information. + - cache_dir: the directory to store the compiled graph, to + accelerate Inductor compilation. By default, it will use + model-related information to generate a cache directory. - backend: the backend for compilation. It needs to be a string. - "" (empty string): use the default backend. - "eager"/"openxla"/...: use the specified backend registered in PyTorch. @@ -2302,12 +2580,10 @@ class CompilationConfig(BaseModel): """ # noqa level: int = 0 debug_dump_path: str = "" + cache_dir: str = "" backend: str = "" custom_ops: List[str] = Field(default_factory=list) - splitting_ops: List[str] = Field(default_factory=lambda: [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - ]) + splitting_ops: List[str] = Field(default=None) # type: ignore use_inductor: bool = True candidate_compile_sizes: Optional[List[int]] = Field(default=None) @@ -2371,12 +2647,37 @@ def model_post_init(self, __context: Any) -> None: enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr compilation_time: float = PrivateAttr + # should be InductorHashCache, but Pydantic does not support it + inductor_hash_cache: Any = PrivateAttr # Per-model forward context # Mainly used to store attention cls # Map from layer name to the attention cls static_forward_context: Dict[str, Any] = PrivateAttr + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: List[Any] = [] + factors.append(self.level) + factors.append(self.backend) + factors.append(self.custom_ops) + factors.append(self.splitting_ops) + factors.append(self.use_inductor) + factors.append(self.inductor_compile_config) + factors.append(self.inductor_passes) + factors.append(self.pass_config.uuid()) + return hashlib.sha256(str(factors).encode()).hexdigest() + def __repr__(self) -> str: exclude = { "static_forward_context", @@ -2405,6 +2706,27 @@ def model_post_init(self, __context: Any) -> None: count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + if self.splitting_ops is None: + if envs.VLLM_USE_V1: + # v1 must split the graph on attention ops + # for piecewise cudagraph + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + else: + # v0 can use full graph compilation without splitting, + # splitting is optional. + # right now we still need it. kv cache shape + # will be included in the graph if we don't split + # the graph. + # TODO: hide kv cache in static forward context + # so that inductor does not see it. + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( @@ -2444,6 +2766,30 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: # TODO: pass user-specified backend to piecewise compilation # merge with the config use_inductor assert self.level == CompilationLevel.PIECEWISE + + if not self.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + hash_key = vllm_config.compute_hash() + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key, + f"rank_{vllm_config.parallel_config.rank}") + os.makedirs(cache_dir, exist_ok=True) + self.cache_dir = cache_dir + + disabled = envs.VLLM_DISABLE_COMPILE_CACHE + from vllm.compilation.backends import InductorHashCache + self.inductor_hash_cache: InductorHashCache = InductorHashCache( + self.cache_dir, disabled=disabled) + if disabled: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info( + "Using cache directory: %s for vLLM's torch.compile", + self.cache_dir) + from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) @@ -2520,6 +2866,67 @@ class VllmConfig: init=True) # type: ignore instance_id: str = "" + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + + # summarize vllm config + vllm_factors: List[Any] = [] + from vllm import __version__ + vllm_factors.append(__version__) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + if self.decoding_config: + vllm_factors.append(self.decoding_config.compute_hash()) + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + if self.prompt_adapter_config: + vllm_factors.append(self.prompt_adapter_config.compute_hash()) + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + + factors.append(vllm_factors) + + hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] + return hash_str + def pad_for_cudagraph(self, batch_size: int) -> int: # if batch_size > self.compilation_config.max_capture_size, # it should raise an IndexError. diff --git a/vllm/envs.py b/vllm/envs.py index da17b747ea215..18870c1c6b51a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -71,6 +71,7 @@ VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 + VLLM_DISABLE_COMPILE_CACHE: bool = False def get_default_cache_root(): @@ -463,6 +464,8 @@ def get_default_config_root(): lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), "VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), + "VLLM_DISABLE_COMPILE_CACHE": + lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), } # end-env-vars-definition From 35bae114a89e03e3dc6a6d2f758378e58938bffa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Dec 2024 17:22:38 -0800 Subject: [PATCH 006/200] fix gh200 tests on main (#11246) Signed-off-by: youkaichao --- .buildkite/run-gh200-test.sh | 4 ++-- docs/source/serving/deploying_with_docker.rst | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh index d25510c47fe6b..d06604f96f2b8 100644 --- a/.buildkite/run-gh200-test.sh +++ b/.buildkite/run-gh200-test.sh @@ -6,8 +6,8 @@ set -ex # Try building the docker image DOCKER_BUILDKIT=1 docker build . 
\ - --target test \ - -platform "linux/arm64" \ + --target vllm-openai \ + --platform "linux/arm64" \ -t gh200-test \ --build-arg max_jobs=66 \ --build-arg nvcc_threads=2 \ diff --git a/docs/source/serving/deploying_with_docker.rst b/docs/source/serving/deploying_with_docker.rst index 11a9f12fd17cd..56f0020a1011a 100644 --- a/docs/source/serving/deploying_with_docker.rst +++ b/docs/source/serving/deploying_with_docker.rst @@ -54,16 +54,13 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `-- # Example of building on Nvidia GH200 server. (Memory usage: ~12GB, Build time: ~1475s / ~25 min, Image size: 7.26GB) $ DOCKER_BUILDKIT=1 sudo docker build . \ --target vllm-openai \ - -platform "linux/arm64" \ + --platform "linux/arm64" \ -t vllm/vllm-gh200-openai:latest \ --build-arg max_jobs=66 \ --build-arg nvcc_threads=2 \ --build-arg torch_cuda_arch_list="9.0+PTX" \ --build-arg vllm_fa_cmake_gpu_arches="90-real" - - - To run vLLM: .. code-block:: console From 0064f697d318a2ce38342f7c20754cf229311b8b Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 16 Dec 2024 22:39:58 -0500 Subject: [PATCH 007/200] [CI] Add test case with JSON schema using references + use xgrammar by default with OpenAI parse (#10935) Signed-off-by: mgoin --- tests/entrypoints/conftest.py | 39 +++++++++++++++++++ tests/entrypoints/llm/test_guided_generate.py | 28 +++++++++++++ vllm/entrypoints/openai/protocol.py | 2 +- 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 0f7d15e1d85aa..ef74062ce4b41 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -100,6 +100,45 @@ def sample_complex_json_schema(): } +@pytest.fixture +def sample_definition_json_schema(): + return { + '$defs': { + 'Step': { + 'properties': { + 'explanation': { + 'title': 'Explanation', + 'type': 'string' + }, + 'output': { + 'title': 'Output', + 'type': 'string' + } + }, + 'required': ['explanation', 'output'], + 'title': 'Step', + 'type': 'object' + } + }, + 'properties': { + 'steps': { + 'items': { + '$ref': '#/$defs/Step' + }, + 'title': 'Steps', + 'type': 'array' + }, + 'final_answer': { + 'title': 'Final Answer', + 'type': 'string' + } + }, + 'required': ['steps', 'final_answer'], + 'title': 'MathReasoning', + 'type': 'object' + } + + @pytest.fixture def sample_guided_choice(): return [ diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index de6257cfc551c..ed50ec6bbc9eb 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -104,6 +104,34 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm): schema=sample_complex_json_schema) +@pytest.mark.skip_global_cleanup +def test_guided_definition_json_completion(sample_definition_json_schema, llm): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_definition_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for solving 8x + 7 = -23 " + f"that fits this schema: {sample_definition_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: 
{generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_definition_json_schema) + + @pytest.mark.skip_global_cleanup def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6ed7c2e9dcd6b..5a70e0952666b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -387,7 +387,7 @@ def to_sampling_params( assert json_schema is not None self.guided_json = json_schema.json_schema if self.guided_decoding_backend is None: - self.guided_decoding_backend = "lm-format-enforcer" + self.guided_decoding_backend = "xgrammar" guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json, From 66d4b16724226e9f377551198cc7425c12ddafae Mon Sep 17 00:00:00 2001 From: kYLe Date: Tue, 17 Dec 2024 00:09:58 -0600 Subject: [PATCH 008/200] [Frontend] Add OpenAI API support for input_audio (#11027) Signed-off-by: DarkLight1337 Co-authored-by: DarkLight1337 --- .../serving/openai_compatible_server.md | 10 +- docs/source/usage/multimodal_inputs.rst | 90 ++++++++++++- ...i_chat_completion_client_for_multimodal.py | 34 ++++- tests/entrypoints/openai/test_audio.py | 125 +++++++++++++++++- vllm/entrypoints/chat_utils.py | 65 +++++++-- 5 files changed, 301 insertions(+), 23 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 14a5b02d72aa5..1bc8d32d2d161 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -34,11 +34,6 @@ We currently support the following OpenAI APIs: - *Note: `suffix` parameter is not supported.* - [Chat Completions API](#chat-api) (`/v1/chat/completions`) - Only applicable to [text generation models](../models/generative_models.rst) (`--task generate`) with a [chat template](#chat-template). - - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst). - - *Note: `image_url.detail` parameter is not supported.* - - We also support `audio_url` content type for audio files. - - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema. - - *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).* - *Note: `parallel_tool_calls` and `user` parameters are ignored.* - [Embeddings API](#embeddings-api) (`/v1/embeddings`) - Only applicable to [embedding models](../models/pooling_models.rst) (`--task embed`). @@ -209,6 +204,11 @@ The following extra parameters are supported: Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/chat) for more details. +We support both [Vision](https://platform.openai.com/docs/guides/vision)- and +[Audio](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in)-related parameters; +see our [Multimodal Inputs](../usage/multimodal_inputs.rst) guide for more information. +- *Note: `image_url.detail` parameter is not supported.* + #### Extra parameters The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. 
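
These extra parameters can be forwarded through the OpenAI SDK's `extra_body` argument. A minimal sketch, assuming a vLLM server is already running locally (e.g. started with `vllm serve`); the base URL, API key, model name, and prompt below are placeholders, not values defined in this change:

```python
# Passing vLLM-specific sampling parameters to the Chat Completions API
# via the OpenAI client's `extra_body`. The base URL, API key, and model
# name are placeholders for a locally running `vllm serve` instance.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

completion = client.chat.completions.create(
    model="facebook/opt-125m",  # placeholder; use whichever model you serve
    messages=[{"role": "user", "content": "Name three prime numbers."}],
    max_completion_tokens=32,
    # Fields outside the OpenAI spec (e.g. top_k, repetition_penalty) are
    # accepted by vLLM when sent through extra_body.
    extra_body={"top_k": 20, "repetition_penalty": 1.05},
)

print(completion.choices[0].message.content)
```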
diff --git a/docs/source/usage/multimodal_inputs.rst b/docs/source/usage/multimodal_inputs.rst index 1e00f26f9a3ba..680382e457cc5 100644 --- a/docs/source/usage/multimodal_inputs.rst +++ b/docs/source/usage/multimodal_inputs.rst @@ -315,7 +315,95 @@ You can use `these tests `_. +Here is a simple example using Ultravox-v0.3. + +First, launch the OpenAI-compatible server: + +.. code-block:: bash + + vllm serve fixie-ai/ultravox-v0_3 + +Then, you can use the OpenAI client as follows: + +.. code-block:: python + + import base64 + import requests + from openai import OpenAI + from vllm.assets.audio import AudioAsset + + def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode('utf-8') + + return result + + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + # Any format supported by librosa is supported + audio_url = AudioAsset("winning_call").url + audio_base64 = encode_base64_content_from_url(audio_url) + + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav" + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from input audio:", result) + +Alternatively, you can pass :code:`audio_url`, which is the audio counterpart of :code:`image_url` for image input: + +.. code-block:: python + + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output from audio url:", result) A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py `_. diff --git a/examples/openai_chat_completion_client_for_multimodal.py b/examples/openai_chat_completion_client_for_multimodal.py index 0ec4f71dddf93..6a160fd70423f 100644 --- a/examples/openai_chat_completion_client_for_multimodal.py +++ b/examples/openai_chat_completion_client_for_multimodal.py @@ -153,10 +153,37 @@ def run_multi_image() -> None: # Audio input inference def run_audio() -> None: - # Any format supported by librosa is supported audio_url = AudioAsset("winning_call").url + audio_base64 = encode_base64_content_from_url(audio_url) + + # OpenAI-compatible schema (`input_audio`) + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" 
+ }, + { + "type": "input_audio", + "input_audio": { + # Any format supported by librosa is supported + "data": audio_base64, + "format": "wav" + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from input audio:", result) - # Use audio url in the payload + # HTTP URL chat_completion_from_url = client.chat.completions.create( messages=[{ "role": @@ -169,6 +196,7 @@ def run_audio() -> None: { "type": "audio_url", "audio_url": { + # Any format supported by librosa is supported "url": audio_url }, }, @@ -181,7 +209,7 @@ def run_audio() -> None: result = chat_completion_from_url.choices[0].message.content print("Chat completion output from audio url:", result) - audio_base64 = encode_base64_content_from_url(audio_url) + # base64 URL chat_completion_from_base64 = client.chat.completions.create( messages=[{ "role": diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index b579dcbb5c402..0a29d77e73abc 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -155,6 +155,61 @@ async def test_single_chat_session_audio_base64encoded( assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +async def test_single_chat_session_input_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str, + base64_encoded_audio: Dict[str, str]): + messages = [{ + "role": + "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav" + } + }, + { + "type": "text", + "text": "What's happening in this audio?" 
+ }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=202, total_tokens=212) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) @@ -212,11 +267,72 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, + model_name: str, audio_url: str, + base64_encoded_audio: Dict[str, + str]): + messages = [{ + "role": + "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav" + } + }, + { + "type": "text", + "text": "What's happening in this audio?" 
+ }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) + output = chat_completion.choices[0].message.content + stop_reason = chat_completion.choices[0].finish_reason + + # test streaming + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert delta.content + assert "".join(chunks) == output + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, - audio_url: str): + audio_url: str, + base64_encoded_audio: Dict[str, str]): messages = [{ "role": @@ -229,9 +345,10 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, } }, { - "type": "audio_url", - "audio_url": { - "url": audio_url + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav" } }, { diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index aaa5cd759366a..3df08c740d65b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -13,7 +13,8 @@ # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import (ChatCompletionAssistantMessageParam, - ChatCompletionContentPartImageParam) + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) from openai.types.chat import (ChatCompletionContentPartRefusalParam, @@ -105,6 +106,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, CustomChatCompletionContentSimpleImageParam, CustomChatCompletionContentSimpleAudioParam, @@ -519,6 +521,10 @@ def parse_image(self, image_url: str) -> None: def parse_audio(self, audio_url: str) -> None: raise NotImplementedError + @abstractmethod + def parse_input_audio(self, input_audio: Dict[str, str]) -> None: + raise NotImplementedError + @abstractmethod def parse_video(self, video_url: str) -> None: raise NotImplementedError @@ -545,6 +551,15 @@ def parse_audio(self, audio_url: str) -> None: placeholder = self._tracker.add("audio", audio) self._add_placeholder(placeholder) + def parse_input_audio(self, input_audio: Dict[str, str]) -> None: + input_audio_data = input_audio.get("data","") + input_audio_format = input_audio.get("format","") + audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" + audio = get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + def parse_video(self, video_url: str) -> None: video = 
get_and_parse_video(video_url) @@ -574,6 +589,15 @@ def parse_audio(self, audio_url: str) -> None: placeholder = self._tracker.add("audio", audio_coro) self._add_placeholder(placeholder) + def parse_input_audio(self, input_audio: Dict[str, str]) -> None: + input_audio_data = input_audio.get("data","") + input_audio_format = input_audio.get("format","") + audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" + audio_coro = async_get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) + def parse_video(self, video_url: str) -> None: video = async_get_and_parse_video(video_url) @@ -667,17 +691,22 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _TextParser = partial(cast, ChatCompletionContentPartTextParam) _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) # Define a mapping from part types to their corresponding parsing functions. -MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { +MM_PARSER_MAP: Dict[str, + Callable[[ChatCompletionContentPartParam], + Union[str, Dict[str,str]]]] = { "text": lambda part: _TextParser(part).get("text", ""), "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", ""), "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""), + "input_audio": + lambda part: _InputAudioParser(part).get("input_audio", {}), "refusal": lambda part: _RefusalParser(part).get("refusal", ""), "video_url": @@ -686,7 +715,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> Tuple[str, str]: + part: ChatCompletionContentPartParam) -> Tuple[str, + Union[str, Dict[str, str]]]: """ Parses a given multi-modal content part based on its type. @@ -717,6 +747,7 @@ def _parse_chat_message_content_mm_part( return part_type, content # Handle missing 'type' but provided direct URL fields. 
+ # 'type' is required field by pydantic if part_type is None: if part.get("image_url") is not None: image_params = cast(CustomChatCompletionContentSimpleImageParam, @@ -726,6 +757,9 @@ def _parse_chat_message_content_mm_part( audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part) return "audio_url", audio_params.get("audio_url", "") + if part.get("input_audio") is not None: + input_audio_params = cast(Dict[str, str], part) + return "input_audio", input_audio_params if part.get("video_url") is not None: video_params = cast(CustomChatCompletionContentSimpleVideoParam, part) @@ -739,7 +773,7 @@ def _parse_chat_message_content_mm_part( VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", - "audio_url", "video_url") + "audio_url", "input_audio", "video_url") def _parse_chat_message_content_parts( @@ -795,7 +829,7 @@ def _parse_chat_message_content_part( # Handle structured dictionary parts part_type, content = _parse_chat_message_content_mm_part(part) - # if part_type is text/refusal/image_url/audio_url/video_url but + # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but # content is empty, log a warning and skip if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: logger.warning( @@ -804,18 +838,30 @@ def _parse_chat_message_content_part( return None if part_type in ("text", "refusal"): - return {'type': 'text', 'text': content} if wrap_dicts else content + str_content = cast(str, content) + if wrap_dicts: + return {'type': 'text', 'text': str_content} + else: + return str_content if part_type == "image_url": - mm_parser.parse_image(content) + str_content = cast(str, content) + mm_parser.parse_image(str_content) return {'type': 'image'} if wrap_dicts else None if part_type == "audio_url": - mm_parser.parse_audio(content) + str_content = cast(str, content) + mm_parser.parse_audio(str_content) + return {'type': 'audio'} if wrap_dicts else None + + if part_type == "input_audio": + dict_content = cast(Dict[str, str], content) + mm_parser.parse_input_audio(dict_content) return {'type': 'audio'} if wrap_dicts else None if part_type == "video_url": - mm_parser.parse_video(content) + str_content = cast(str, content) + mm_parser.parse_video(str_content) return {'type': 'video'} if wrap_dicts else None raise NotImplementedError(f"Unknown part type: {part_type}") @@ -840,7 +886,6 @@ def _parse_chat_message_content( content = [ ChatCompletionContentPartTextParam(type="text", text=content) ] - result = _parse_chat_message_content_parts( role, content, # type: ignore From 59c9b6ebeba79b2d744eec86734a7e13b03dcab7 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 16 Dec 2024 22:10:57 -0800 Subject: [PATCH 009/200] [V1][VLM] Proper memory profiling for image language models (#11210) Signed-off-by: Roger Wang Co-authored-by: ywang96 --- vllm/config.py | 8 ++++ vllm/model_executor/models/pixtral.py | 5 ++ vllm/multimodal/registry.py | 23 +++++++-- vllm/v1/core/scheduler.py | 7 ++- vllm/v1/engine/mm_input_mapper.py | 1 + vllm/v1/worker/gpu_model_runner.py | 67 ++++++++++++++++++++++++--- 6 files changed, 98 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9cfd08024ea7b..9ecd3e72afa9f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1280,6 +1280,14 @@ class SchedulerConfig: is_multimodal_model: bool = False + # FIXME(woosuk & ywang96): Below are placeholder values. We need to + # calculate the actual values from the configurations. 
+ # Multimodal encoder run compute budget, only used in V1 + max_num_encoder_input_tokens = 16384 + + # Multimodal encoder cache size, only used in V1 + encoder_cache_size = 16384 + # Whether to perform preemption by swapping or # recomputation. If not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 161d6b41bfa5f..f05ea195e043d 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -245,6 +245,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # Do not split, return as tensor of shape [1, fs, hs] return image_embeds.unsqueeze(0) + # If the last split index is the last index in image_tokens, we + # ignore it to avoid empty split tensor + if split_indices[-1] == len(image_tokens): + split_indices = split_indices[:-1] + image_embeds = image_embeds.tensor_split(split_indices.cpu()) return image_embeds diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 03f8814a95356..6cd79d414c978 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -200,6 +200,23 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) + def get_max_tokens_per_item_by_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: + """ + Get the maximum number of tokens per data item from each modality + for profiling the memory usage of a model. + + Note: + This is currently directly used only in V1. + """ + + return { + key: plugin.get_max_multimodal_tokens(model_config) + for key, plugin in self._plugins.items() + } + def get_max_tokens_by_modality( self, model_config: "ModelConfig", @@ -216,9 +233,9 @@ def get_max_tokens_by_modality( limits_per_plugin = self._limits_by_model[model_config] return { - key: (limits_per_plugin[key] * - plugin.get_max_multimodal_tokens(model_config)) - for key, plugin in self._plugins.items() + key: limits_per_plugin[key] * max_tokens_per_mm_item + for key, max_tokens_per_mm_item in + self.get_max_tokens_per_item_by_modality(model_config).items() } def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f76364f64033d..178532e477dae 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -73,14 +73,13 @@ def __init__( # NOTE(woosuk): Here, "encoder" includes the vision encoder (and # projector if needed). Currently, we assume that the encoder also # has the Transformer architecture (e.g., ViT). - # FIXME(woosuk): Below are placeholder values. We need to calculate the - # actual values from the configurations. - self.max_num_encoder_input_tokens = 16384 + self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501 # NOTE(woosuk): For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized and used, regardless of # the cache size. This is because the memory space for the encoder cache # is preallocated in the profiling run. 
- self.encoder_cache_manager = EncoderCacheManager(cache_size=16384) + self.encoder_cache_manager = EncoderCacheManager( + cache_size=self.scheduler_config.encoder_cache_size) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index cca27c2218af7..6cdeba6f3f71e 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -54,6 +54,7 @@ def cache_hit_ratio(self, steps): logger.debug("MMInputMapper: cache_hit_ratio = %.2f ", self.mm_cache_hits / self.mm_cache_total) + # TODO: Support modalities beyond image. def process_inputs( self, mm_data: MultiModalDataDict, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 67166fb05085c..c6fab5f05fcb3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,15 +10,16 @@ from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) +from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -35,7 +36,6 @@ def __init__( self, vllm_config: VllmConfig, device: torch.device, - input_registry: InputRegistry = INPUT_REGISTRY, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -77,7 +77,12 @@ def __init__( self.hidden_size = model_config.get_hidden_size() # Multi-modal data support - self.input_registry = input_registry + self.input_registry = INPUT_REGISTRY + self.mm_registry = MULTIMODAL_REGISTRY + # NOTE: mm_input_mapper is only used for memory profiling. + self.mm_input_mapper = MMInputMapperClient(self.model_config) + self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 + self.encoder_cache_size = self.scheduler_config.encoder_cache_size # Lazy initialization # self.model: nn.Module # Set after load_model @@ -599,8 +604,6 @@ def _dummy_run( return hidden_states def profile_run(self) -> None: - # TODO(woosuk): Profile the max memory usage of the encoder and - # the encoder cache. # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. # the `dtype` argument does not matter, and we use `float32` as @@ -612,6 +615,57 @@ def profile_run(self) -> None: torch.tensor([], dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] + + # Profile with multimodal encoder & encoder cache. + # TODO (ywang96): generalize this beyond image modality since + # mm_input_mapper only supports image inputs. + if self.is_multimodal_model: + + # Create dummy batch of multimodal inputs. 
+ dummy_request_data = self.input_registry.dummy_data_for_profiling( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_registry=self.mm_registry, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs( + mm_data=dummy_mm_data, + mm_hashes=None, + mm_processor_kwargs=None, + precomputed_mm_inputs=None) + + # NOTE: Currently model is profiled with a single non-text + # modality even when it supports multiple. + max_tokens_per_mm_item = max( + self.mm_registry.get_max_tokens_per_item_by_modality( + self.model_config).values()) + + max_num_mm_items = min( + self.max_num_encoder_input_tokens, + self.encoder_cache_size) // max_tokens_per_mm_item + + # Dummy data definition in V0 may contain multiple multimodal items + # (e.g, multiple images) for a single request, therefore here we + # always replicate first item by max_num_mm_items times since in V1 + # they are scheduled to be processed separately. + batched_dummy_mm_inputs = MultiModalKwargs.batch( + [dummy_mm_kwargs[0]] * max_num_mm_items) + batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( + batched_dummy_mm_inputs, device=self.device) + + # Run multimodal encoder. + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + assert len(dummy_encoder_outputs) == max_num_mm_items, ( + "Expected dimension 0 of encoder outputs to match the number " + f"of multimodal data items: {max_num_mm_items}, got " + f"{len(dummy_encoder_outputs)=} instead. This is most likely " + "due to the 'get_multimodal_embeddings' method of the model " + "not implemented correctly.") + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + # Trigger compilation for general shape. hidden_states = self._dummy_run(self.model, self.max_num_tokens, dummy_kv_caches) @@ -620,6 +674,7 @@ def profile_run(self) -> None: # TODO(woosuk): Consider the memory usage of the sampler. 
torch.cuda.synchronize() del hidden_states, logits + self.encoder_cache.clear() gc.collect() def capture_model(self) -> None: From e88db68cf5712956f36e77c288699592327b15bd Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 17 Dec 2024 14:11:06 +0800 Subject: [PATCH 010/200] [Platform] platform agnostic for EngineArgs initialization (#11225) Signed-off-by: wangxiyuan --- vllm/engine/arg_utils.py | 8 ++------ vllm/platforms/cpu.py | 3 +++ vllm/platforms/cuda.py | 4 ++++ vllm/platforms/hpu.py | 6 ++++++ vllm/platforms/neuron.py | 6 ++++++ vllm/platforms/openvino.py | 3 +++ vllm/platforms/rocm.py | 4 ++++ vllm/platforms/tpu.py | 5 +++++ vllm/platforms/xpu.py | 4 ++++ 9 files changed, 37 insertions(+), 6 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 06b8542779dc0..f6d276fe7c0c8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -112,9 +112,7 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None - # NOTE(kzawora): default block size for Gaudi should be 128 - # smaller sizes still work, but very inefficiently - block_size: int = 16 if not current_platform.is_hpu() else 128 + block_size: Optional[int] = None enable_prefix_caching: Optional[bool] = None disable_sliding_window: bool = False use_v2_block_manager: bool = True @@ -1036,9 +1034,7 @@ def create_engine_config(self, self.enable_prefix_caching = False cache_config = CacheConfig( - # neuron needs block_size = max_model_len - block_size=self.block_size if self.device != "neuron" else - (self.max_model_len if self.max_model_len is not None else 0), + block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index aad8755d9fcd8..d95a2b4cd5565 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -60,6 +60,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ae1fd6d5ce068..3c5350b778345 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -137,6 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 2b947d280f9f8..0a44f2b74163a 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -48,6 +48,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" + # NOTE(kzawora): default block size for Gaudi should be 128 + # smaller sizes still work, but very inefficiently + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 128 + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on HPU.") diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 86113523385f6..a4bbbd27c8a89 
100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -33,6 +33,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.worker_cls = \ "vllm.worker.neuron_worker.NeuronWorker" + cache_config = vllm_config.cache_config + if cache_config: + # neuron needs block_size = max_model_len + vllm_config.cache_config.block_size = \ + vllm_config.model_config.max_model_len + @classmethod def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on Neuron.") diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index ccd94e8adb3b1..16eb8dc81efc2 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -87,6 +87,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # check and update cache config ov_core = ov.Core() cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": if not OpenVinoPlatform.is_openvino_cpu(): logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 0133f26a0b1bc..7778b565372cb 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -84,6 +84,10 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 10d874349f36b..77f5c8401424b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -46,6 +46,11 @@ def inference_mode(cls): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config import CompilationLevel + + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + compilation_config = vllm_config.compilation_config if compilation_config.level == CompilationLevel.NO_COMPILATION: # TPU does not support NO_COMPILATION diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c20190e789d7e..78e17c2afec65 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -51,6 +51,10 @@ def inference_mode(): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + cache_config = vllm_config.cache_config + if cache_config and cache_config.block_size is None: + cache_config.block_size = 16 + # check and update model config model_config = vllm_config.model_config if model_config.dtype == torch.bfloat16: From 2bfdbf2a36256bb08547cea3d4ef83b5d27c4b04 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 17 Dec 2024 01:11:33 -0500 Subject: [PATCH 011/200] [V1][Core] Use weakref.finalize instead of atexit (#11242) Signed-off-by: Tyler Michael Smith --- vllm/v1/engine/core_client.py | 13 ++----------- vllm/v1/executor/multiproc_executor.py | 10 +++------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index ff25a9b2e9cac..d56fcbdb1e7c4 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,5 +1,5 @@ -import atexit import os +import weakref from typing import List, Optional import 
msgspec @@ -165,15 +165,9 @@ def __init__( ready_path=ready_path, # type: ignore[misc] **kwargs, ) - atexit.register(self.shutdown) + self._finalizer = weakref.finalize(self, self.shutdown) def shutdown(self): - # During final garbage collection in process shutdown, atexit may be - # None. - if atexit: - # in case shutdown gets called via __del__ first - atexit.unregister(self.shutdown) - # Shut down the zmq context. self.ctx.destroy(linger=0) @@ -197,9 +191,6 @@ def shutdown(self): os.remove(socket_file) self.proc_handle = None - def __del__(self): - self.shutdown() - class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 17441dacdc5cf..128101aa6956d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,9 +1,9 @@ -import atexit import os import pickle import signal import sys import time +import weakref from dataclasses import dataclass from enum import Enum, auto from multiprocessing.process import BaseProcess @@ -37,7 +37,7 @@ class MultiprocExecutor(Executor): def __init__(self, vllm_config: VllmConfig) -> None: # Call self.shutdown at exit to clean up # and ensure workers will be terminated. - atexit.register(self.shutdown) + self._finalizer = weakref.finalize(self, self.shutdown) self.vllm_config = vllm_config self.parallel_config = vllm_config.parallel_config @@ -195,14 +195,10 @@ def _cleanup_sockets(self): os.remove(socket_path) def shutdown(self): - if atexit: - # in case shutdown was called explicitly, we don't need to call it - # again - atexit.unregister(self.shutdown) """Properly shut down the executor and its workers""" if getattr(self, 'shutting_down', False): self.shutting_down = True - for w in self.workers: #TODO: not sure if needed + for w in self.workers: w.worker_response_mq = None self._ensure_worker_termination() From 02222a0256f60319f5bcd56d1d036a943d6334f8 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 16 Dec 2024 22:57:02 -0800 Subject: [PATCH 012/200] [Misc] Kernel Benchmark for `RMSNorm` (#11241) Signed-off-by: Roger Wang Co-authored-by: Xiaoyu Zhang --- benchmarks/kernels/benchmark_rmsnorm.py | 262 ++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 benchmarks/kernels/benchmark_rmsnorm.py diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py new file mode 100644 index 0000000000000..baa5de0fff1bd --- /dev/null +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -0,0 +1,262 @@ +import itertools +from typing import Optional, Tuple, Union + +import torch +import triton +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn + +from vllm import _custom_ops as vllm_ops + + +class HuggingFaceRMSNorm(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return 
x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_vllm = rmsnorm_vllm( + x.clone(), weight, + residual.clone() if residual is not None else None) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"VLLM output={output_vllm}") + + if torch.allclose(output_naive, output_flashinfer, atol=1e-2, + rtol=1e-2) and torch.allclose( + output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + 
ylabel="us", + plot_name= + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) From f9ecbb18bf03338a4272c933a49a87021363b048 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 17 Dec 2024 16:37:04 +0800 Subject: [PATCH 013/200] [Misc] Allow passing logits_soft_cap for xformers backend (#11252) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/attention/backends/xformers.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e2e989efb020c..3e59b3603d2c6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -17,9 +17,7 @@ is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) -from vllm.logger import init_logger - -logger = init_logger(__name__) +from vllm.utils import print_warning_once class XFormersBackend(AttentionBackend): @@ -386,8 +384,8 @@ def __init__( raise ValueError( "XFormers does not support block-sparse attention.") if logits_soft_cap is not None: - raise ValueError( - "XFormers does not support attention logits soft capping.") + print_warning_once("XFormers does not support logits soft cap. 
" + "Outputs may be slightly off.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) From 2d1b9baa8f57fc59912c7bcd07fd630fb9d72c9d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Dec 2024 13:26:32 -0700 Subject: [PATCH 014/200] [Bugfix] Fix request cancellation without polling (#11190) --- tests/entrypoints/openai/test_basic.py | 51 ++++++++++++++++ tests/test_utils.py | 6 +- tests/utils.py | 11 ++-- vllm/engine/async_llm_engine.py | 46 +++++++++------ vllm/entrypoints/api_server.py | 11 ++-- vllm/entrypoints/openai/api_server.py | 8 +++ vllm/entrypoints/openai/serving_chat.py | 5 -- vllm/entrypoints/openai/serving_completion.py | 3 +- vllm/entrypoints/openai/serving_embedding.py | 5 +- vllm/entrypoints/openai/serving_score.py | 5 +- vllm/entrypoints/utils.py | 57 ++++++++++++++++++ vllm/utils.py | 59 ++----------------- 12 files changed, 164 insertions(+), 103 deletions(-) create mode 100644 vllm/entrypoints/utils.py diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 4616f363cc04a..547c1fd020928 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,6 +1,8 @@ +import asyncio from http import HTTPStatus from typing import List +import openai import pytest import pytest_asyncio import requests @@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer): response = requests.get(server.url_for("health")) assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server_args", + [ + pytest.param(["--max-model-len", "10100"], + id="default-frontend-multiprocessing"), + pytest.param( + ["--disable-frontend-multiprocessing", "--max-model-len", "10100"], + id="disable-frontend-multiprocessing") + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_request_cancellation(server: RemoteOpenAIServer): + # clunky test: send an ungodly amount of load in with short timeouts + # then ensure that it still responds quickly afterwards + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client(timeout=0.5) + tasks = [] + # Request about 2 million tokens + for _ in range(200): + task = asyncio.create_task( + client.chat.completions.create(messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_body={"min_tokens": 10000})) + tasks.append(task) + + done, pending = await asyncio.wait(tasks, + return_when=asyncio.ALL_COMPLETED) + + # Make sure all requests were sent to the server and timed out + # (We don't want to hide other errors like 400s that would invalidate this + # test) + assert len(pending) == 0 + for d in done: + with pytest.raises(openai.APITimeoutError): + d.result() + + # If the server had not cancelled all the other requests, then it would not + # be able to respond to this one within the timeout + client = server.get_async_client(timeout=5) + response = await client.chat.completions.create(messages=chat_input, + model=MODEL_NAME, + max_tokens=10) + + assert len(response.choices) == 1 diff --git a/tests/test_utils.py b/tests/test_utils.py index 0bc9e5bc32a46..32a6b0aed66aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import asyncio import os import socket -from functools import partial from typing import AsyncIterator, Tuple import pytest @@ -26,10 +25,7 @@ async def mock_async_iterator(idx: int): print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator = 
merge_async_iterators(*iterators, - is_cancelled=partial(asyncio.sleep, - 0, - result=False)) + merged_iterator = merge_async_iterators(*iterators) async def stream_output(generator: AsyncIterator[Tuple[int, str]]): async for idx, output in generator: diff --git a/tests/utils.py b/tests/utils.py index afeb708f3bcdc..bf3d88194e4ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -163,12 +163,11 @@ def get_client(self): api_key=self.DUMMY_API_KEY, ) - def get_async_client(self): - return openai.AsyncOpenAI( - base_url=self.url_for("v1"), - api_key=self.DUMMY_API_KEY, - max_retries=0, - ) + def get_async_client(self, **kwargs): + return openai.AsyncOpenAI(base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs) def _test_completion( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 32396fd10188d..f50e20cf70323 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1065,16 +1065,20 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ): - yield LLMEngine.validate_output(output, RequestOutput) + try: + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ): + yield LLMEngine.validate_output(output, RequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise async def encode( self, @@ -1147,15 +1151,19 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ): - yield LLMEngine.validate_output(output, PoolingRequestOutput) + try: + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, PoolingRequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index ea3c93f733038..95da1c6e7b9bf 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -17,11 +17,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation, - random_uuid) +from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -47,6 +47,11 @@ async def generate(request: Request) -> Response: - other fields: the sampling parameters (See `SamplingParams` for details). 
""" request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) @@ -54,8 +59,6 @@ async def generate(request: Request) -> Response: assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) - results_generator = iterate_with_cancellation( - results_generator, is_cancelled=request.is_disconnected) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 14e3a34ce141c..00e2d1a56f160 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, @@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response: @router.post("/tokenize") +@with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post("/detokenize") +@with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -353,6 +356,7 @@ async def show_version(): @router.post("/v1/chat/completions") +@with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) @@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") +@with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: @@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") +@with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) if handler is None: @@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @router.post("/score") +@with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) if handler is None: @@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): @router.post("/v1/score") +@with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( "To indicate that Score API is not part of standard OpenAI API, we " diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 527418c635093..81bce0dd370bb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -32,7 +32,6 @@ from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls -from vllm.utils import iterate_with_cancellation logger = 
init_logger(__name__) @@ -234,10 +233,6 @@ async def create_chat_completion( assert len(generators) == 1 result_generator, = generators - if raw_request: - result_generator = iterate_with_cancellation( - result_generator, raw_request.is_disconnected) - # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index bd39a4c42e938..5cf9df92e296e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -159,8 +159,7 @@ async def create_completion( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, is_cancelled=raw_request.is_disconnected) + result_generator = merge_async_iterators(*generators) model_name = self._get_model_name(lora_request) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index fd501ad4f833e..879276646d2ba 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -202,10 +202,7 @@ async def create_embedding( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected if raw_request else None, - ) + result_generator = merge_async_iterators(*generators) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 6f5cc14ac37cc..101d170bee4d6 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -186,10 +186,7 @@ async def create_score( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected if raw_request else None, - ) + result_generator = merge_async_iterators(*generators) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py new file mode 100644 index 0000000000000..e8a78d216d0f0 --- /dev/null +++ b/vllm/entrypoints/utils.py @@ -0,0 +1,57 @@ +import asyncio +import functools + +from fastapi import Request + + +async def listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. Instead this follows the pattern from + starlette.StreamingResponse, which simultaneously awaits on two tasks- one + to wait for an http disconnect message, and the other to do the work that we + want done. When the first task finishes, the other is cancelled. + + A core assumption of this method is that the body of the request has already + been read. This is a safe assumption to make for fastapi handlers that have + already parsed the body of the request into a pydantic model for us. + This decorator is unsafe to use elsewhere, as it will consume and throw away + all incoming messages for the request while it looks for a disconnect + message. 
+ + In the case where a `StreamingResponse` is returned by the handler, this + wrapper will stop listening for disconnects and instead the response object + will start listening for disconnects. + """ + + # Functools.wraps is required for this wrapper to appear to fastapi as a + # normal route handler, with the correct request type hinting. + @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + + # The request is either the second positional arg or `raw_request` + request = args[1] if len(args) > 1 else kwargs["raw_request"] + + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + + if handler_task in done: + return handler_task.result() + return None + + return wrapper diff --git a/vllm/utils.py b/vllm/utils.py index 73d2ae25f15ca..38c7dea6d2d3d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -20,7 +20,7 @@ import uuid import warnings import weakref -from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping from dataclasses import dataclass, field @@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None], return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] -async def iterate_with_cancellation( - iterator: AsyncGenerator[T, None], - is_cancelled: Callable[[], Awaitable[bool]], -) -> AsyncGenerator[T, None]: - """Convert async iterator into one that polls the provided function - at least once per second to check for client cancellation. - """ - - loop = asyncio.get_running_loop() - - awaits: List[Future[T]] = [_next_task(iterator, loop)] - next_cancel_check: float = 0 - while True: - done, pending = await asyncio.wait(awaits, timeout=1.5) - - # Check for cancellation at most once per second - time_now = time.time() - if time_now >= next_cancel_check: - if await is_cancelled(): - with contextlib.suppress(BaseException): - awaits[0].cancel() - await iterator.aclose() - raise asyncio.CancelledError("client cancelled") - next_cancel_check = time_now + 1 - - if done: - try: - item = await awaits[0] - awaits[0] = _next_task(iterator, loop) - yield item - except StopAsyncIteration: - # we are done - return - - async def merge_async_iterators( - *iterators: AsyncGenerator[T, None], - is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None, -) -> AsyncGenerator[Tuple[int, T], None]: + *iterators: AsyncGenerator[T, + None], ) -> AsyncGenerator[Tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. - - It also optionally polls a provided function at least once per second - to check for client cancellation. 
""" loop = asyncio.get_running_loop() awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} - timeout = None if is_cancelled is None else 1.5 - next_cancel_check: float = 0 try: while awaits: - done, pending = await asyncio.wait(awaits.keys(), - return_when=FIRST_COMPLETED, - timeout=timeout) - if is_cancelled is not None: - # Check for cancellation at most once per second - time_now = time.time() - if time_now >= next_cancel_check: - if await is_cancelled(): - raise asyncio.CancelledError("client cancelled") - next_cancel_check = time_now + 1 + done, _ = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED) for d in done: pair = awaits.pop(d) try: From c77eb8a33ceb62858d951ffef87ae626a0d09973 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 17 Dec 2024 19:34:06 -0500 Subject: [PATCH 015/200] [Bugfix] Set temperature=0.7 in test_guided_choice_chat (#11264) --- tests/entrypoints/openai/test_chat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8d23a2be6f9bb..47c521a9b5eb5 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -482,6 +482,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, model=MODEL_NAME, messages=messages, max_completion_tokens=10, + temperature=0.7, extra_body=dict(guided_choice=sample_guided_choice, guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content @@ -496,6 +497,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, model=MODEL_NAME, messages=messages, max_completion_tokens=10, + temperature=0.7, extra_body=dict(guided_choice=sample_guided_choice, guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content From bf8717ebaea8d74279df84fbe127ad22cf62e219 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 17 Dec 2024 16:37:59 -0800 Subject: [PATCH 016/200] [V1] Prefix caching for vision language models (#11187) Signed-off-by: Cody Yu --- tests/v1/core/test_prefix_caching.py | 88 +++++++++++++++++++- tests/v1/engine/test_engine_args.py | 15 ---- vllm/engine/arg_utils.py | 27 ++++--- vllm/inputs/data.py | 20 +++++ vllm/multimodal/inputs.py | 3 + vllm/v1/core/kv_cache_manager.py | 74 +++++++++++------ vllm/v1/core/kv_cache_utils.py | 115 ++++++++++++++++++++++++--- vllm/v1/core/scheduler.py | 2 + vllm/v1/engine/async_llm.py | 10 ++- vllm/v1/engine/core.py | 8 +- vllm/v1/engine/llm_engine.py | 9 ++- vllm/v1/engine/mm_input_mapper.py | 33 ++++---- vllm/v1/engine/processor.py | 12 +-- vllm/v1/request.py | 24 +++++- 14 files changed, 342 insertions(+), 98 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 00f7b0fcfe1dc..ed04f0a373c51 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2,16 +2,23 @@ import pytest from vllm.inputs import token_inputs +from vllm.multimodal.inputs import PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens -def make_request(request_id, prompt_token_ids): +def make_request(request_id, + prompt_token_ids, + mm_positions=None, + mm_hashes=None): return Request( request_id=request_id, - inputs=token_inputs(prompt_token_ids=prompt_token_ids), + 
inputs=token_inputs(prompt_token_ids=prompt_token_ids, + multi_modal_placeholders={"image": mm_positions} + if mm_positions else None, + multi_modal_hashes=mm_hashes), sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, arrival_time=0, @@ -38,6 +45,7 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks = manager.get_computed_blocks(req0) + assert len(req0.kv_block_hashes) == 3 assert not computed_blocks blocks = manager.allocate_slots(req0, 55, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -61,6 +69,7 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks = manager.get_computed_blocks(req1) + assert len(req1.kv_block_hashes) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) @@ -90,6 +99,7 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_block = manager.get_computed_blocks(req2) + assert len(req2.kv_block_hashes) == 3 assert [b.block_id for b in computed_block] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) @@ -416,3 +426,77 @@ def test_cache_blocks(): ) assert len(manager.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None + + +def test_mm_prefix_caching(): + """ + This tests that the multi-modal prefix caching is correct. + """ + manager = KVCacheManager( + block_size=16, + num_gpu_blocks=10, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=16, + ) + + # Common prompt tokens (T is text tokens and P is image placeholder tokens) + # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1] + common_token_ids = list(range(10)) + [-1] * 6 + common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2 + common_token_ids += [-1] * 16 + + common_mm_positions = [ + PlaceholderRange(offset=11, length=10), + PlaceholderRange(offset=30, length=18), + ] + common_mm_hashes = ["aaa", "bbb"] + + # A unique image plus some text tokens. + unique_token_ids = [-1] * 7 + [100] * 4 + all_token_ids = common_token_ids + unique_token_ids + mm_positions = common_mm_positions + [ + PlaceholderRange(offset=48, length=7) + ] + mm_hashes = common_mm_hashes + ["ccc"] + req0 = make_request("0", + all_token_ids, + mm_positions=mm_positions, + mm_hashes=mm_hashes) + computed_blocks = manager.get_computed_blocks(req0) + + # Completed block should have hashes with extra keys. + assert not computed_blocks + assert len(req0.kv_block_hashes) == 3 + assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), ) + assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0)) + assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), ) + + blocks = manager.allocate_slots(req0, 59, computed_blocks) + assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + req0.num_computed_tokens = 59 + + # Append slots without allocating a new block. + for _ in range(5): + req0.append_output_token_ids(8) + new_blocks = manager.append_slots(req0, 5) + assert new_blocks is not None and len(new_blocks) == 0 + + # The just completed block should have hashes with extra keys. + assert len(req0.kv_block_hashes) == 4 + assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), ) + + # Cache hit. 
+ unique_token_ids = [-1] * 7 + [200] * 5 + all_token_ids = common_token_ids + unique_token_ids + mm_positions = common_mm_positions + [ + PlaceholderRange(offset=48, length=7) + ] + mm_hashes = common_mm_hashes + ["ccc"] + req1 = make_request("1", + all_token_ids, + mm_positions=mm_positions, + mm_hashes=mm_hashes) + computed_blocks = manager.get_computed_blocks(req1) + assert len(computed_blocks) == 3 diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index ac5e7dde525a7..ff38a4568ecb1 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -31,14 +31,6 @@ def test_prefix_caching_from_cli(): assert engine_args.enable_prefix_caching -def test_defaults(): - engine_args = EngineArgs(model="facebook/opt-125m") - - # Assert V1 defaults - assert (engine_args.enable_prefix_caching - ), "V1 turns on prefix caching by default" - - def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") vllm_config: VllmConfig = engine_args.create_engine_config( @@ -52,10 +44,3 @@ def test_defaults_with_usage_context(): UsageContext.OPENAI_API_SERVER) assert vllm_config.scheduler_config.max_num_seqs == 1024 assert vllm_config.scheduler_config.max_num_batched_tokens == 2048 - - -def test_prefix_cache_disabled_with_multimodel(): - engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf") - - vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) - assert not vllm_config.cache_config.enable_prefix_caching diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f6d276fe7c0c8..674577f23eba6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -205,6 +205,7 @@ def __post_init__(self): # by user. if self.enable_prefix_caching is None: self.enable_prefix_caching = bool(envs.VLLM_USE_V1) + # Override max_num_seqs if it's not set by user. if self.max_num_seqs is None: self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024 @@ -1026,11 +1027,11 @@ def create_engine_config(self, device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() - if model_config.is_multimodal_model: - if self.enable_prefix_caching: - logger.warning( - "--enable-prefix-caching is currently not " - "supported for multimodal models and has been disabled.") + if (model_config.is_multimodal_model and not envs.VLLM_USE_V1 + and self.enable_prefix_caching): + logger.warning("--enable-prefix-caching is currently not " + "supported for multimodal models in v0 and " + "has been disabled.") self.enable_prefix_caching = False cache_config = CacheConfig( @@ -1249,11 +1250,14 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None: # When no user override, set the default values based on the usage # context. # TODO(woosuk): Tune the default values for different hardware. 
- if self.max_num_batched_tokens is None: - if usage_context == UsageContext.LLM_CLASS: - self.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - self.max_num_batched_tokens = 2048 + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 8192, + UsageContext.OPENAI_API_SERVER: 2048, + } + if (self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens): + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] logger.warning( "Setting max_num_batched_tokens to %d for %s usage context.", self.max_num_batched_tokens, usage_context.value) @@ -1263,9 +1267,6 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - if engine_config.model_config.is_multimodal_model: - # TODO (ywang96): Enable APC by default when VLM supports it. - assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 85aaaa776907f..d54cbb5c37819 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -162,6 +162,11 @@ class TokenInputs(TypedDict): Placeholder ranges for the multi-modal data. """ + multi_modal_hashes: NotRequired[List[str]] + """ + The hashes of the multi-modal data. + """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the @@ -177,6 +182,7 @@ def token_inputs( prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None, + multi_modal_hashes: Optional[List[str]] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: @@ -191,6 +197,8 @@ def token_inputs( inputs["multi_modal_data"] = multi_modal_data if multi_modal_inputs is not None: inputs["multi_modal_inputs"] = multi_modal_inputs + if multi_modal_hashes is not None: + inputs["multi_modal_hashes"] = multi_modal_hashes if multi_modal_placeholders is not None: inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: @@ -295,6 +303,18 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: assert_never(inputs) + @cached_property + def multi_modal_hashes(self) -> List[str]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_hashes", []) + + if inputs["type"] == "multimodal": + return inputs.get("mm_hashes", []) + + assert_never(inputs) + @cached_property def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": inputs = self.inputs diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 229a8fbdf5831..c00943a5f26d9 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict): mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" + mm_hashes: NotRequired[List[str]] + """The hashes of the multi-modal data.""" + mm_placeholders: MultiModalPlaceholderDict """ For each modality, information about the placeholder tokens in diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index aaa44c930e324..61a3f5fd6d841 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,7 +4,9 @@ from vllm.logger import init_logger from vllm.utils 
import cdiv from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, hash_block_tokens, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens, hash_request_tokens) from vllm.v1.request import Request @@ -83,10 +85,12 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: computed_blocks = [] - # TODO(rickyx): potentially we could cache this so we don't have to - # recompute it every time. - block_hashes = hash_request_tokens(self.block_size, - request.all_token_ids) + # The block hashes for the request may already be computed + # if the request was preempted and resumed. + if not request.kv_block_hashes: + request.set_kv_block_hashes( + hash_request_tokens(self.block_size, request)) + block_hashes = request.kv_block_hashes for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not @@ -242,14 +246,16 @@ def allocate_slots( num_computed_tokens = len(computed_blocks) * self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - self._cache_full_blocks( - request=request, - blk_start_idx=len(computed_blocks), - # The new full blocks are the full blocks that are not computed. - full_blocks=self.req_to_blocks[request.request_id] - [len(computed_blocks):num_full_blocks], - prev_block=computed_blocks[-1] if computed_blocks else None, - ) + new_full_blocks = self.req_to_blocks[ + request.request_id][len(computed_blocks):num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=len(computed_blocks), + # The new full blocks are the full blocks that are not computed. + full_blocks=new_full_blocks, + prev_block=computed_blocks[-1] if computed_blocks else None, + ) return new_blocks @@ -376,6 +382,8 @@ def _cache_full_blocks( full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ + num_cached_block_hashes = len(request.kv_block_hashes) + # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None if prev_block is not None: @@ -387,17 +395,35 @@ def _cache_full_blocks( for i, blk in enumerate(full_blocks): blk_idx = blk_start_idx + i - block_tokens = request.all_token_ids[blk_idx * - self.block_size:(blk_idx + - 1) * - self.block_size] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got {len(block_tokens)} " - f"at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, block_tokens) + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = request.kv_block_hashes[blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * self.block_size + end_token_idx = (blk_idx + 1) * self.block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. 
Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, extra_keys) + request.append_kv_block_hashes(block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 0ba338aa5a3d2..d80ea128c7749 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,20 +1,25 @@ """KV-Cache Utilities.""" from collections.abc import Sequence from dataclasses import dataclass -from typing import List, NamedTuple, Optional, Tuple +from typing import Any, List, NamedTuple, Optional, Tuple from vllm.logger import init_logger +from vllm.v1.request import Request logger = init_logger(__name__) class BlockHashType(NamedTuple): - """Hash value of a block and the token IDs in the block. - The reason we keep a tuple of token IDs is to make sure no hash - collision happens when the hash value is the same. + """Hash value of a block (int), the token IDs in the block, and extra keys. + The reason we keep a tuple of token IDs and extra keys is to make sure + no hash collision happens when the hash value is the same. """ + # Hash value of the block in an integer. hash_value: int + # Token IDs in the block. token_ids: Tuple[int, ...] + # Extra keys for the block. + extra_keys: Optional[Any] = None @dataclass @@ -159,8 +164,80 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: return ret -def hash_block_tokens(parent_block_hash: Optional[int], - curr_block_token_ids: Sequence[int]) -> BlockHashType: +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + For multi-modal inputs, the extra keys are (mm_hash, start_offset) that + indicate a mm input contained in the block and its starting offset in + the block tokens. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if not mm_positions: + return None, start_mm_idx + + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match. This " + "is likely because you do not enable MM preprocessor hashing. " + "Please set mm_cache_preprocessor=True.") + + # Note that we assume mm_positions is sorted by offset. + # We do not need to check all mm inputs if the start token index is out of + # range. This usually happens in the late prefill phase and decoding phase. + if mm_positions[-1]["offset"] + mm_positions[-1][ + "length"] < start_token_idx: + return None, start_mm_idx + + # Support start_mm_idx == -1 to indicate the last mm input. 
+ if start_mm_idx < 0: + assert -start_mm_idx <= len(mm_positions) + start_mm_idx = len(mm_positions) + start_mm_idx + + extra_keys = [] + curr_mm_idx = start_mm_idx + while mm_positions and curr_mm_idx < len(mm_positions): + assert mm_hashes[curr_mm_idx] is not None + offset = mm_positions[curr_mm_idx]["offset"] + length = mm_positions[curr_mm_idx]["length"] + if end_token_idx > offset: + if start_token_idx > offset + length: + # This block has passed the current mm input. + curr_mm_idx += 1 + continue + + # The block contains the current mm input. + mm_start = max(0, start_token_idx - offset) + extra_keys.append((mm_hashes[curr_mm_idx], mm_start)) + if end_token_idx >= offset + length: + # If this block contains the end of the current mm input, + # move to the next mm input as this block may also contain + # the next mm input. + curr_mm_idx += 1 + else: + # Otherwise this block is done with mm inputs. + break + else: + # This block has not reached the current mm input. + break + return tuple(extra_keys), curr_mm_idx + + +def hash_block_tokens( + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int], if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. + extra_keys: Extra keys for the block. Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. """ return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), - tuple(curr_block_token_ids)) + tuple(curr_block_token_ids), extra_keys) def hash_request_tokens(block_size: int, - token_ids: Sequence[int]) -> List[BlockHashType]: + request: Request) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. - token_ids: A sequence of token ids in the request. + request: The request object. Returns: The list of computed hash values. """ + token_ids = request.all_token_ids + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match.") + + # TODO: Extend this to support other features such as LoRA. + need_extra_keys = bool(mm_positions) + extra_keys = None + curr_mm_idx = 0 + ret = [] parent_block_hash_value = None for start in range(0, len(token_ids), block_size): @@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int, # Do not hash the block if it is not full. if len(block_token_ids) < block_size: break + + # Add extra keys if the block is a multi-modal block. 
+ if need_extra_keys: + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start, end, curr_mm_idx) + block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids) + block_token_ids, extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 178532e477dae..08e7c0fd4dc9b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -516,6 +516,7 @@ class NewRequestData: prompt_token_ids: List[int] prompt: Optional[str] mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams block_ids: List[int] @@ -533,6 +534,7 @@ def from_request( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, sampling_params=request.sampling_params, block_ids=block_ids, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b36de5f66917c..41fb4b25d45bb 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -60,9 +60,13 @@ def __init__( self.client_aborted_requests: List[str] = [] # Processor (converts Inputs --> EngineCoreRequests). - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry) + self.processor = Processor( + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + tokenizer=self.tokenizer, + input_registry=input_registry, + ) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 56d4dc67e4a0e..497d5db5b4c99 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -65,7 +65,8 @@ def __init__( self._last_logging_time = time.time() - self.mm_input_mapper_server = MMInputMapperServer() + self.mm_input_mapper_server = MMInputMapperServer( + vllm_config.model_config) def _initialize_kv_caches(self, cache_config: CacheConfig) -> Tuple[int, int]: @@ -98,9 +99,8 @@ def add_request(self, request: EngineCoreRequest): # MM mapper, so anything that has a hash must have a HIT cache # entry here as well. 
assert request.mm_inputs is not None - request.mm_inputs, request.mm_hashes = ( - self.mm_input_mapper_server.process_inputs( - request.mm_inputs, request.mm_hashes)) + request.mm_inputs = self.mm_input_mapper_server.process_inputs( + request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 15dedbd0f9529..bea8c5502f612 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -55,9 +55,12 @@ def __init__( self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry, mm_registry) + self.processor = Processor(model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + tokenizer=self.tokenizer, + input_registry=input_registry, + mm_registry=mm_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput) self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 6cdeba6f3f71e..e53ba092ede04 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import PIL from blake3 import blake3 @@ -42,6 +42,8 @@ def __init__( model_config) self.mm_registry.init_mm_limits_per_prompt(model_config) + # Init cache + self.use_cache = model_config.mm_cache_preprocessor self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) # DEBUG: Set to None to disable @@ -61,7 +63,7 @@ def process_inputs( mm_hashes: Optional[List[str]], mm_processor_kwargs: Optional[Dict[str, Any]], precomputed_mm_inputs: Optional[List[MultiModalKwargs]], - ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]: + ) -> List[MultiModalKwargs]: if precomputed_mm_inputs is None: image_inputs = mm_data["image"] if not isinstance(image_inputs, list): @@ -70,26 +72,21 @@ def process_inputs( else: num_inputs = len(precomputed_mm_inputs) - # Check if hash is enabled - use_hash = mm_hashes is not None - if use_hash: + # Sanity + if self.use_cache: assert mm_hashes is not None - assert num_inputs == len( - mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format( - num_inputs, len(mm_hashes)) + assert num_inputs == len(mm_hashes) # Process each image input separately, so that later we can schedule # them in a fine-grained manner. 
# Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_hashes: Optional[List[str]] = [] if use_hash else None ret_inputs: List[MultiModalKwargs] = [] for input_id in range(num_inputs): if self.mm_debug_cache_hit_ratio_steps is not None: self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) - mm_hash = None mm_input = None - if use_hash: + if self.use_cache: assert mm_hashes is not None mm_hash = mm_hashes[input_id] mm_input = self.mm_cache.get(mm_hash) @@ -106,7 +103,7 @@ def process_inputs( mm_processor_kwargs=mm_processor_kwargs, ) - if use_hash: + if self.use_cache: # Add to cache assert mm_hash is not None self.mm_cache.put(mm_hash, mm_input) @@ -114,18 +111,15 @@ def process_inputs( self.mm_cache_hits += 1 mm_input = None # Avoids sending mm_input to Server - if use_hash: - assert mm_hash is not None - assert ret_hashes is not None - ret_hashes.append(mm_hash) ret_inputs.append(mm_input) - return ret_inputs, ret_hashes + return ret_inputs class MMInputMapperServer: - def __init__(self, ): + def __init__(self, model_config): + self.use_cache = model_config.mm_cache_preprocessor self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) def process_inputs( @@ -135,6 +129,9 @@ def process_inputs( ) -> List[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) + if not self.use_cache: + return mm_inputs + full_mm_inputs = [] for mm_input, mm_hash in zip(mm_inputs, mm_hashes): assert mm_hash is not None diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 679bf8e25e9ca..732757d6b0ac2 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,7 +1,7 @@ import time from typing import Any, Dict, Mapping, Optional, Tuple, Union -from vllm.config import LoRAConfig, ModelConfig +from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs @@ -23,6 +23,7 @@ class Processor: def __init__( self, model_config: ModelConfig, + cache_config: CacheConfig, lora_config: Optional[LoRAConfig], tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, @@ -45,8 +46,9 @@ def __init__( self.mm_input_mapper_client = MMInputMapperClient(model_config) # Multi-modal hasher (for images) - self.mm_hasher = MMHasher( - ) if model_config.mm_cache_preprocessor else None + self.use_hash = model_config.mm_cache_preprocessor or \ + cache_config.enable_prefix_caching + self.mm_hasher = MMHasher() # TODO: run in an ThreadpoolExecutor or BackgroundProcess. # This ideally should releases the GIL, so we should not block the @@ -77,7 +79,7 @@ def process_inputs( # Compute MM hashes (if enabled) mm_hashes = None - if self.mm_hasher is not None: + if self.use_hash: mm_hashes = self.mm_hasher.hash(prompt) # Process inputs. 
@@ -118,7 +120,7 @@ def process_inputs( # Apply MM mapper mm_inputs = None if len(decoder_inputs.multi_modal_data) > 0: - mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs( + mm_inputs = self.mm_input_mapper_client.process_inputs( decoder_inputs.multi_modal_data, mm_hashes, decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1737d096e811d..f4783ae366ef0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,5 +1,5 @@ import enum -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.lora.request import LoRARequest @@ -9,6 +9,9 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import BlockHashType + class Request: @@ -45,6 +48,7 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 + # Multi-modal input metadata. mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. @@ -56,6 +60,12 @@ def __init__( if self.inputs.multi_modal_inputs: self.mm_inputs = self.inputs.multi_modal_inputs + self.mm_hashes: List[str] = self.inputs.multi_modal_hashes + + # Cache the computed kv block hashes of the request to avoid + # recomputing. + self._kv_block_hashes: List[BlockHashType] = [] + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( @@ -65,6 +75,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": prompt=request.prompt, multi_modal_data=None, multi_modal_inputs=request.mm_inputs, + multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, mm_processor_kwargs=None, ), @@ -121,6 +132,17 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens + @property + def kv_block_hashes(self) -> ConstantList["BlockHashType"]: + # Prevent directly appending to the kv_block_hashes. + return ConstantList(self._kv_block_hashes) + + def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: + self._kv_block_hashes = value + + def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: + self._kv_block_hashes.append(block_hash) + class RequestStatus(enum.IntEnum): """Status of a request.""" From 866fa4550d572f4ff3521ccf503e0df2e76591a1 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 18 Dec 2024 01:39:07 +0100 Subject: [PATCH 017/200] [Bugfix] Restore support for larger block sizes (#11259) Signed-off-by: Konrad Zawora --- vllm/config.py | 4 ++++ vllm/engine/arg_utils.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9ecd3e72afa9f..307cf9c8d5b2a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -917,6 +917,10 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if (current_platform.is_cuda() and self.block_size is not None + and self.block_size > 32): + raise ValueError("CUDA Paged Attention kernel only supports " + f"block sizes up to 32. 
Got {self.block_size}.") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 674577f23eba6..64cc4592c2861 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -424,10 +424,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], + choices=[8, 16, 32, 64, 128], help='Token block size for contiguous chunks of ' 'tokens. This is ignored on neuron devices and ' - 'set to max-model-len') + 'set to max-model-len. On CUDA devices, ' + 'only block sizes up to 32 are supported. ' + 'On HPU devices, block size defaults to 128.') parser.add_argument( "--enable-prefix-caching", From 8b79f9e107fd4214187bf65485b3ea1bb3191a46 Mon Sep 17 00:00:00 2001 From: Wallas Henrique Date: Wed, 18 Dec 2024 03:34:08 -0300 Subject: [PATCH 018/200] [Bugfix] Fix guided decoding with tokenizer mode mistral (#11046) --- .buildkite/test-pipeline.yaml | 6 +- requirements-common.txt | 3 +- .../model_executor/test_guided_processors.py | 54 ++++++++- .../decoder_only/language/test_mistral.py | 86 ++++++++++++- .../guided_decoding/xgrammar_decoding.py | 113 +++++++++++------- vllm/transformers_utils/tokenizer.py | 2 +- vllm/transformers_utils/tokenizers/mistral.py | 5 +- 7 files changed, 217 insertions(+), 52 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 44f47fac1c1b3..b563c96343f92 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -224,8 +224,12 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/model_executor/layers + - vllm/model_executor/guided_decoding - tests/test_logits_processor - command: pytest -v -s test_logits_processor.py + - tests/model_executor/test_guided_processors + commands: + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 30min source_file_dependencies: diff --git a/requirements-common.txt b/requirements-common.txt index bd2b4b7a01668..1c935303c8d79 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,12 +14,13 @@ aiohttp openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) uvicorn[standard] pydantic >= 2.9 # Required for fastapi >= 0.113.0 -pillow # Required for image processing prometheus_client >= 0.18.0 +pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines == 0.1.11 +lark == 1.2.2 xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 9f4d81b583141..3334c0df149b5 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,13 +1,19 @@ +import pickle + import pytest import torch from transformers import AutoTokenizer +from vllm.config import ModelConfig from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) + get_guided_decoding_logits_processor, + get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, 
RegexLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams +MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' + def test_guided_logits_processors(sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" @@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -async def test_guided_logits_processor_black_box(backend: str, sample_regex, +@pytest.mark.parametrize("is_local", [True, False]) +async def test_guided_logits_processor_black_box(backend: str, is_local: bool, + sample_regex, sample_json_schema): - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + + config = ModelConfig( + MODEL_NAME, + task="generate", + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="bfloat16", + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) - regex_lp = await get_guided_decoding_logits_processor( - regex_request, tokenizer) + + regex_lp = get_local_guided_decoding_logits_processor( + regex_request, tokenizer, config) if is_local else \ + await get_guided_decoding_logits_processor( + regex_request, tokenizer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, tokenizer) + json_request, tokenizer, config) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") + + +def test_pickle_xgrammar_tokenizer_data(): + + # TODO: move to another test file for xgrammar + try: + import xgrammar as xgr + except ImportError: + pytest.skip("Could not import xgrammar to run test") + + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + TokenizerData) + tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW) + pickled = pickle.dumps(tokenizer_data) + + assert pickled is not None + + depickled: TokenizerData = pickle.loads(pickled) + + assert depickled is not None + assert depickled.vocab_type == xgr.VocabType.RAW diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 99b5d5694f9f7..bdc1571784b5d 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -3,17 +3,20 @@ Run `pytest tests/models/test_mistral.py`. 
""" import copy +import json +import jsonschema +import jsonschema.exceptions import pytest -from vllm import SamplingParams from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close MODELS = [ - "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", ] MISTRAL_FORMAT_MODELS = [ @@ -126,6 +129,45 @@ } ] +SAMPLE_JSON_SCHEMA = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] +} + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -251,3 +293,43 @@ def test_mistral_function_calling( assert parsed_message.tool_calls[ 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa assert parsed_message.content is None + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("guided_backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) +def test_mistral_guided_decoding( + vllm_runner, + model: str, + guided_backend: str, +) -> None: + with vllm_runner(model, dtype='bfloat16', + tokenizer_mode="mistral") as vllm_model: + + guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, + backend=guided_backend) + params = SamplingParams(max_tokens=512, + temperature=0.7, + guided_decoding=guided_decoding) + + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}" + }] + outputs = vllm_model.model.chat(messages, sampling_params=params) + + generated_text = outputs[0].outputs[0].text + json_response = json.loads(generated_text) + assert outputs is not None + + try: + jsonschema.validate(instance=json_response, + schema=SAMPLE_JSON_SCHEMA) + except jsonschema.exceptions.ValidationError: + pytest.fail("Generated response is not valid with JSON schema") diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index fc45e37cf6f06..5b97f03257502 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any import torch from transformers import PreTrainedTokenizerFast @@ -16,6 +16,7 @@ from vllm.model_executor.guided_decoding.xgrammar_utils import ( convert_lark_to_gbnf, grammar_is_likely_lark) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor( return XGrammarLogitsProcessor(config) -class TokenizerData(NamedTuple): +@dataclass(frozen=True) +class TokenizerData: """Immutable container for cached tokenizer data.""" - encoded_vocab: list[str] - 
stop_token_ids: list[int] | None - backend_str: str + encoded_vocab: list[str] = field(default_factory=list) + stop_token_ids: list[int] | None = None + # These fields are mutually exclusive: `backend_str` is used to create a + # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is + # used within the constructor of TokenizeInfo + backend_str: str | None = None + vocab_type: xgr.VocabType | None = None + + def __post_init__(self): + # Check for mutual exclusive + assert not (self.backend_str and self.vocab_type), \ + "backend_str and vocab_type are mutual exclusive" class TokenizerDataCache: @@ -68,18 +79,27 @@ def get_tokenizer_data(cls, "get_vocab method.") from e stop_token_ids = None - backend_str = xgr.VocabType.RAW + backend_str = "" + vocab_type = xgr.VocabType.RAW + + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + if isinstance(tokenizer, PreTrainedTokenizerFast): backend_str = tokenizer.backend_tokenizer.to_str() - if stop_token_ids is None and hasattr( - tokenizer, - "eos_token_id") and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + vocab_type = None + + elif isinstance(tokenizer, MistralTokenizer): + # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + vocab_type = xgr.VocabType.BYTE_FALLBACK cls._cache[tokenizer_hash] = TokenizerData( encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, - backend_str=backend_str) + backend_str=backend_str, + vocab_type=vocab_type) return cls._cache[tokenizer_hash] @@ -98,11 +118,30 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: cache_key = str(config.tokenizer_hash) if cache_key not in cls._cache: - assert config.encoded_vocab is not None - tokenizer_info = xgr.TokenizerInfo._create_from_handle( - xgr_core.TokenizerInfo.from_huggingface( - config.encoded_vocab, config.backend_str, - config.vocab_size, config.stop_token_ids)) + assert config.tokenizer_data is not None + assert config.tokenizer_data.encoded_vocab is not None + + config_data = config.tokenizer_data + + # In TokenizerDataCache.get_tokenizer_data, a serializable + # tokenizer_data is created and cached. This data is used to build + # a tokenizer_info and create an xgrammar compiler. + # - If tokenizer_data has backend_str set, use + # xgr_core.TokenizerInfo.from_huggingface (a C++ bind). + # - Otherwise, use the default constructor with vocab_type. + # - xgr_core.TokenizerInfo.from_huggingface != + # xgr.TokenizerInfo.from_huggingface. 
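+            # Concretely, a PreTrainedTokenizerFast supplies backend_str
+            # (tokenizer.backend_tokenizer.to_str()) with vocab_type unset,
+            # while a MistralTokenizer supplies only vocab_type
+            # (xgr.VocabType.BYTE_FALLBACK), hence the branch on backend_str.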
+ if config_data.backend_str: + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config_data.encoded_vocab, config_data.backend_str, + config.vocab_size, config_data.stop_token_ids)) + else: + tokenizer_info = xgr.TokenizerInfo( + config_data.encoded_vocab, + config_data.vocab_type, + vocab_size=config.vocab_size, + stop_token_ids=config_data.stop_token_ids) cls._cache[cache_key] = xgr.GrammarCompiler( tokenizer_info, max_threads=config.max_threads) @@ -118,10 +157,7 @@ class GrammarConfig: grammar_str: str | None = None json_object: bool | None = None max_threads: int = 8 - # Only populated if tokenizer_hash not in cache - encoded_vocab: list[str] | None = None - stop_token_ids: list[int] | None = None - backend_str: str | None = None + tokenizer_data: TokenizerData | None = None @classmethod def from_guided_params(cls, @@ -132,9 +168,6 @@ def from_guided_params(cls, tokenizer_hash = hash(tokenizer) tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) - encoded_vocab = tokenizer_data.encoded_vocab - stop_token_ids = tokenizer_data.stop_token_ids - backend_str = tokenizer_data.backend_str if guided_params.json: if not isinstance(guided_params.json, str): @@ -152,11 +185,9 @@ def from_guided_params(cls, return cls(json_str=json_str, vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + tokenizer_data=tokenizer_data) elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must convert Lark if grammar_is_likely_lark(guided_params.grammar): @@ -181,19 +212,17 @@ def from_guided_params(cls, return cls(grammar_str=grammar_str, vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + tokenizer_data=tokenizer_data) elif guided_params.json_object: - return cls(json_object=True, - vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + return cls( + json_object=True, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) else: raise ValueError( "Currently only support JSON and EBNF grammar mode for xgrammar" @@ -269,10 +298,14 @@ def __call__(self, input_ids: list[int], # fill_next_token_bitmask so we move it to the device of scores device_type = scores.device.type if device_type != "cuda": - scores = scores.to("cpu") + scores = scores.to("cpu").unsqueeze(0) + + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. 
Hence the + # unsqueeze above for scores, to match the token bitmask shape xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) if device_type != "cuda": - scores = scores.to(device_type) + scores = scores.to(device_type).squeeze() return scores diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 54f9f895fe541..e6701f4c4b835 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -132,7 +132,7 @@ def get_tokenizer( if is_from_mistral_org and tokenizer_mode != "mistral": warnings.warn( 'It is strongly recommended to run mistral models with ' - '`--tokenizer_mode "mistral"` to ensure correct ' + '`--tokenizer-mode "mistral"` to ensure correct ' 'encoding and decoding.', FutureWarning, stacklevel=2) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 83b3c37d6f04c..17d722e3d88fe 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -314,12 +314,15 @@ def _token_to_id(t: str): if regular_tokens: decoded_list.append( - self.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens)) # type: ignore decoded = ''.join(decoded_list) return decoded + # WARN: Outlines logits processors can overwrite this method. + # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer + # for more. def decode(self, ids: Union[List[int], int], skip_special_tokens: bool = True) -> str: From f04e407e6b6b9ce65c16cffda836f05c2ad32682 Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Wed, 18 Dec 2024 14:34:23 +0800 Subject: [PATCH 019/200] [MISC][XPU]update ipex link for CI fix (#11278) --- requirements-xpu.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements-xpu.txt b/requirements-xpu.txt index e41295792283f..42c6c321d040c 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -9,8 +9,8 @@ setuptools-scm>=8 wheel jinja2 -torch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp310-cp310-linux_x86_64.whl -intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp310-cp310-linux_x86_64.whl -oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp310-cp310-linux_x86_64.whl +torch @ https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp310-cp310-linux_x86_64.whl +intel-extension-for-pytorch @ https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp310-cp310-linux_x86_64.whl +oneccl_bind_pt @ https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp310-cp310-linux_x86_64.whl triton-xpu == 3.0.0b1 From 60508ffda91c22e4cde3b18f149d222211db8886 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 18 Dec 2024 09:57:16 -0500 Subject: [PATCH 020/200] [Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995) Co-authored-by: Faraz Shahsavan Co-authored-by: ilmarkov Co-authored-by: Rahul Tuli Co-authored-by: rshaw@neuralmagic.com --- CMakeLists.txt | 26 +- .../cutlass_benchmarks/sparse_benchmarks.py | 384 ++++++++++++++ benchmarks/cutlass_benchmarks/utils.py | 96 ++++ .../cutlass_benchmarks/w8a8_benchmarks.py | 28 +- 
.../cutlass_benchmarks/weight_shapes.py | 2 +- csrc/core/math.hpp | 7 + csrc/cutlass_extensions/common.cpp | 11 + csrc/cutlass_extensions/common.hpp | 35 ++ .../epilogue/scaled_mm_epilogues_c3x.hpp | 4 +- csrc/ops.h | 9 + csrc/quantization/cutlass_w8a8/common.hpp | 27 - .../cutlass_w8a8/scaled_mm_c2x.cuh | 3 +- .../cutlass_w8a8/scaled_mm_c3x.cu | 3 +- .../cutlass_w8a8/scaled_mm_entry.cu | 12 +- csrc/sparse/cutlass/sparse_compressor_c3x.cu | 163 ++++++ .../sparse/cutlass/sparse_compressor_entry.cu | 42 ++ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 303 +++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 496 ++++++++++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 59 +++ csrc/torch_bindings.cpp | 15 + pyproject.toml | 2 +- tests/kernels/test_semi_structured.py | 131 +++++ tests/quantization/test_compressed_tensors.py | 103 +++- tests/weight_loading/models.txt | 2 + .../run_model_weight_loading_test.sh | 4 + tests/weight_loading/test_weight_loading.py | 7 + vllm/_custom_ops.py | 103 ++++ .../compressed_tensors/compressed_tensors.py | 187 ++++++- .../compressed_tensors/schemes/__init__.py | 15 +- .../schemes/compressed_tensors_24.py | 203 +++++++ 30 files changed, 2365 insertions(+), 117 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/sparse_benchmarks.py create mode 100644 benchmarks/cutlass_benchmarks/utils.py create mode 100644 csrc/core/math.hpp create mode 100644 csrc/cutlass_extensions/common.cpp create mode 100644 csrc/cutlass_extensions/common.hpp delete mode 100644 csrc/quantization/cutlass_w8a8/common.hpp create mode 100644 csrc/sparse/cutlass/sparse_compressor_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_compressor_entry.cu create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu create mode 100644 tests/kernels/test_semi_structured.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bf19b3d227171..51b49a18dddf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) endif() FetchContent_MakeAvailable(cutlass) @@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_compressor_entry.cu" + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # - # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels + # For Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/sparse/cutlass/sparse_compressor_c3x.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " + message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running FP8 quantized models on " + "later if you intend on running FP8 sparse or quantized models on " "Hopper.") else() - message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + message(STATUS "Not building cutlass_c3x as no compatible archs found " "in CUDA target architectures") endif() @@ -404,7 +410,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..3d1c5e392f9e2 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,384 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + 
globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. 
bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 
1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
+ """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000..ef06fcd6604dd --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,96 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, 
k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..d0353bc8cb42a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -8,6 +8,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops @@ -17,31 +18,6 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - # bench def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, @@ -386,4 +362,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000..ba9f40a230c8e --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,7 @@ +#include +#include + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp new file mode 100644 index 0000000000000..3d2093ab94297 --- /dev/null +++ b/csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, 
minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp new file mode 100644 index 0000000000000..85e359aa57113 --- /dev/null +++ b/csrc/cutlass_extensions/common.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..fcc17c7727f94 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -36,13 +36,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors diff --git a/csrc/ops.h b/csrc/ops.h index 816b471d062d2..c145e4eda0845 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -162,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& e, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, + torch::Tensor& e, torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp deleted file mode 100644 index bf04bb400790f..0000000000000 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - 
cutlassGetStatusString(status)) \ - } - -inline uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index d03242f44ab1d..75681f7f37820 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,8 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..8190277997161 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -24,7 +24,8 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e0..4f7b6588ef3f7 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -3,6 +3,8 @@ #include #include +#include "cutlass_extensions/common.hpp" + void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { return false; } -int32_t get_sm_version_num() { - int32_t major_capability, minor_capability; - cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, - 0); - cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, - 0); - int32_t version_num = major_capability * 10 + minor_capability; - return version_num; -} - void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu new file mode 100644 index 0000000000000..218c5317b4de6 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu @@ -0,0 +1,163 @@ +// clang-format will break include orders +// clang-format off +#include + +#include "sparse_scaled_mm_c3x.cuh" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +using namespace cute; +using namespace vllm; + +/// Make A structured sparse by replacing elements with 0 and compress it +template +bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + 
// Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity + TORCH_CHECK(a.stride(1) == 1) + + int m = a.size(0); + int k = a.size(1); + + // Sparse kernel setup; this kernel is not used for matmul, + // but just for setting up the compressor utility + // A matrix configuration + using ElementA = ElementA_; + using LayoutTagA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + // B matrix configuration + using ElementB = ElementA; + using LayoutTagB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + // C/D matrix configuration + using ElementC = float; + using LayoutTagC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Core kernel configurations + using ElementAccumulator = ElementAcc_; + using TileShape = Shape<_128, _128, _128>; + using TileShapeRef = Shape<_128, _128, _64>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = typename std::conditional< + std::is_same_v, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecialized>::type; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using ProblemShape = Shape; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC, + AlignmentC, ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA, + LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + using StrideA = Stride, int64_t>; + + // The n (=1) dimension does not matter for the compressor + typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; + + using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; + using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; + + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; + + // Offline compressor kernel + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, + cutlass::arch::Sm90>; + + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + auto [M, N, K, L] = prob_shape; + + StrideA stride_A; + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(prob_shape, 
stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto a_ptr = static_cast(a.data_ptr()); + + auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); + auto a_meta_ptr = static_cast( + a_meta.data_ptr()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + if (a.dtype() == torch::kBFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } else if (a.dtype() == torch::kFloat8_e4m3fn) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kInt8) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } + return false; +} \ No newline at end of file diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu new file mode 100644 index 0000000000000..d23d937b6ac28 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu @@ -0,0 +1,42 @@ +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a); +#endif + +bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); + TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && + a_nzs.size(1) * 2 == a.size(1) && + a_meta.size(1) * 2 * 4 == a.size(1)); + // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && + a_meta.stride(1) == 1); // Row-major + TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + int32_t version_num = get_sm_version_num(); + + // Guard against compilation issues for sm90 kernels +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X + if (version_num >= 90) { + return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_sparse_mm for a compute capability less than " + "CUDA device capability: ", + version_num); +} diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu new file mode 100644 index 0000000000000..b50e9a3a2c240 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -0,0 +1,303 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +#include 
"sparse_scaled_mm_c3x.cuh" +// clang-format on + +using namespace cute; +using namespace vllm; + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm90_fp8_config_M256::Cutlass3xGemm; + using Cutlass3xGemmM512 = + typename sm90_fp8_config_M512::Cutlass3xGemm; + + using Cutlass3xGemm1 = + typename sm90_fp8_config_1::Cutlass3xGemm; + using Cutlass3xGemm2 = + typename sm90_fp8_config_2::Cutlass3xGemm; + using Cutlass3xGemm3 = + typename sm90_fp8_config_3::Cutlass3xGemm; + using Cutlass3xGemm4 = + typename sm90_fp8_config_4::Cutlass3xGemm; + using Cutlass3xGemm5 = + typename sm90_fp8_config_5::Cutlass3xGemm; + using Cutlass3xGemm6 = + typename sm90_fp8_config_6::Cutlass3xGemm; + using Cutlass3xGemm7 = + typename sm90_fp8_config_7::Cutlass3xGemm; + using Cutlass3xGemm8 = + typename sm90_fp8_config_8::Cutlass3xGemm; + + uint32_t const n = bt_nzs.size(0); + uint32_t const m = a.size(0); // Batch size + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096 || n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 128) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 256) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else { + if (n == 6144 || n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } + + // Otherwise the default heuristic + if (mp2 <= 64) { + // n in [1, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // n in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 256) { + // n in (128, 256] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // n in (256, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template typename Epilogue, + typename... 
EpilogueArgs> +void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kBFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template