Commit 8344f77

Authored by leiwen83, wenlei03, cadedaniel, and comaniac
[Bug fix][Core] fixup ngram not setup correctly (vllm-project#4551)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
1 parent 469f85c commit 8344f77
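
This commit fixes ngram speculative decoding (prompt-lookup decoding): the GPU executor never forwarded `ngram_prompt_lookup_max`/`ngram_prompt_lookup_min` from the `SpeculativeConfig` into `draft_worker_kwargs`, so `SpecDecodeWorker.create_worker` silently fell back to its `ngram_prompt_lookup_max = 0` default and built a `MultiStepWorker` instead of an `NGramWorker`. The diffs below thread the two values through the executor, make `create_worker` consume them unconditionally, and add a test-side assertion that the ngram proposer is actually selected.

For context, here is a minimal sketch of how ngram speculation is enabled from user code in vLLM of this era. The argument names follow the project's speculative-decoding docs of that period, but treat the model name and values as illustrative, not as part of this commit:

    from vllm import LLM, SamplingParams

    # Illustrative only: "[ngram]" asks vLLM to propose draft tokens by
    # ngram lookup in the prompt instead of running a separate draft model.
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="[ngram]",
        num_speculative_tokens=5,
        ngram_prompt_lookup_max=4,  # the setting this commit actually wires up
        use_v2_block_manager=True,
    )

    params = SamplingParams(temperature=0.0, max_tokens=64)
    print(llm.generate(["The quick brown fox"], params)[0].outputs[0].text)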

File tree

3 files changed: +29 -13 lines

tests/spec_decode/e2e/conftest.py

Lines changed: 18 additions & 6 deletions
@@ -55,7 +55,7 @@ def __init__(
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        self.engine_args = AsyncEngineArgs(
+        engine_args = AsyncEngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
@@ -76,6 +76,8 @@ def __init__(
             **kwargs,
         )
         self.request_counter = Counter()
+        self.llm_engine = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS)
 
     def generate(
         self,
@@ -88,9 +90,6 @@ def generate(
         multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
 
-        llm_engine = AsyncLLMEngine.from_engine_args(
-            self.engine_args, usage_context=UsageContext.LLM_CLASS)
-
         if prompts is None:
             raise ValueError("prompts must be provided.")
         if isinstance(prompts, str):
@@ -111,8 +110,8 @@
 
         async def get_output(prompt, sampling_param) -> str:
             request_id = random_uuid()
-            results_generator = llm_engine.generate(prompt, sampling_param,
-                                                    request_id)
+            results_generator = self.llm_engine.generate(
+                prompt, sampling_param, request_id)
             final_output = None
             async for request_output in results_generator:
                 final_output = request_output
@@ -185,12 +184,25 @@ def generator_outer():
     return generator_outer
 
 
+def maybe_assert_ngram_worker(llm):
+    # Verify the proposer worker is ngram if ngram is specified.
+    if (not isinstance(llm, AsyncLLM)
+            and llm.llm_engine.speculative_config is not None
+            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+        from vllm.spec_decode.ngram_worker import NGramWorker
+        assert isinstance(
+            llm.llm_engine.model_executor.driver_worker.proposer_worker,
+            NGramWorker)
+
+
 def get_output_from_llm_generator(
         llm_generator, prompts,
         sampling_params) -> Tuple[List[str], List[List[int]]]:
     tokens = []
     token_ids = []
     for llm in llm_generator():
+        maybe_assert_ngram_worker(llm)
+
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         token_ids = [output.outputs[0].token_ids for output in outputs]
         tokens = [output.outputs[0].text for output in outputs]
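
Two things change in this test harness. First, the `AsyncLLM` wrapper now builds its `AsyncLLMEngine` once in `__init__` and reuses it across `generate` calls, instead of constructing a fresh engine on every call. Second, `maybe_assert_ngram_worker` gives the e2e tests a direct regression check: whenever a config requests ngram lookup, the engine's driver worker must actually hold an `NGramWorker` proposer.

A minimal, self-contained sketch of the build-once/fan-out pattern the harness now follows; `FakeEngine` is a hypothetical stand-in for `AsyncLLMEngine`, whose `generate` streams partial outputs until a request completes:

    import asyncio

    class FakeEngine:
        # Hypothetical stand-in for AsyncLLMEngine: an async generator that
        # yields partial results and finishes with the complete output.
        async def generate(self, prompt, request_id):
            for chunk in (prompt + " ...", prompt + " ... done"):
                yield chunk

    async def run(prompts):
        engine = FakeEngine()  # built once, like self.llm_engine in __init__

        async def get_output(i, prompt):
            final = None
            async for out in engine.generate(prompt, request_id=str(i)):
                final = out  # keep only the last (finished) output
            return final

        return await asyncio.gather(
            *(get_output(i, p) for i, p in enumerate(prompts)))

    print(asyncio.run(run(["a", "b"])))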

vllm/executor/gpu_executor.py

Lines changed: 4 additions & 0 deletions
@@ -82,6 +82,10 @@ def _init_spec_worker(self):
         draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
+            ngram_prompt_lookup_max=self.speculative_config.
+            ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min=self.speculative_config.
+            ngram_prompt_lookup_min,
             # TODO allow draft-model specific load config.
             #load_config=self.load_config,
         )
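
This hunk is the root cause: `draft_worker_kwargs` was assembled without the ngram bounds, so downstream code could never see a nonzero `ngram_prompt_lookup_max`. A small sketch of the before/after behavior, with plain dicts standing in for the real kwargs:

    # Before the fix: the executor omitted the key, and the old guarded
    # lookup in create_worker silently defaulted to 0, disabling NGramWorker.
    kwargs = {}  # what the executor effectively passed for ngram settings
    old_value = (kwargs.pop("ngram_prompt_lookup_max")
                 if "ngram_prompt_lookup_max" in kwargs else 0)
    assert old_value == 0  # ngram silently off

    # After the fix: the executor always supplies both keys, and the bare
    # pop() in create_worker now fails fast if they ever go missing again.
    kwargs = {"ngram_prompt_lookup_max": 4, "ngram_prompt_lookup_min": 1}
    assert kwargs.pop("ngram_prompt_lookup_max") == 4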

vllm/spec_decode/spec_decode_worker.py

Lines changed: 7 additions & 7 deletions
@@ -57,13 +57,10 @@ def create_worker(
         draft_worker_kwargs,
     ) -> "SpecDecodeWorker":
 
-        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
-            ngram_prompt_lookup_max = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
-            ngram_prompt_lookup_min = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-        else:
-            ngram_prompt_lookup_max = 0
+        ngram_prompt_lookup_max = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+        ngram_prompt_lookup_min = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
 
         if ngram_prompt_lookup_max > 0:
             proposer_worker = NGramWorker(**draft_worker_kwargs)
@@ -72,6 +69,9 @@ def create_worker(
         else:
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
+        logger.info("Configuring SpecDecodeWorker with proposer=%s",
+                    type(proposer_worker))
+
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
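
With the executor now guaranteed to supply both keys, `create_worker` pops them unconditionally and selects the proposer from the value alone: `NGramWorker` when `ngram_prompt_lookup_max > 0`, otherwise the draft-model `MultiStepWorker`. The new `logger.info` also makes the chosen proposer visible in logs, which is exactly the signal that was missing while this bug went unnoticed. A condensed sketch of the selection contract, with strings standing in for the worker classes to keep it self-contained:

    def select_proposer(**draft_worker_kwargs):
        # Mirrors create_worker's new behavior: both keys must be present.
        lookup_max = draft_worker_kwargs.pop("ngram_prompt_lookup_max")
        lookup_min = draft_worker_kwargs.pop("ngram_prompt_lookup_min")
        if lookup_max > 0:
            return "NGramWorker", (lookup_min, lookup_max)
        return "MultiStepWorker", None

    print(select_proposer(ngram_prompt_lookup_max=4, ngram_prompt_lookup_min=1))
    # -> ('NGramWorker', (1, 4))
    print(select_proposer(ngram_prompt_lookup_max=0, ngram_prompt_lookup_min=0))
    # -> ('MultiStepWorker', None)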
