diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 89c7e8dafe0..5b8d4419245 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -93,7 +93,7 @@ def run_vllm( scales_path=scales_path, device=device, enable_prefix_caching=enable_prefix_caching, - worker_use_torchrun=args.worker_use_torchrun,) + worker_use_torchrun=worker_use_torchrun,) # Add the requests to the engine. for prompt, _, output_len in requests: