@@ -55,7 +55,7 @@ def __init__(
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        self.engine_args = AsyncEngineArgs(
+        engine_args = AsyncEngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
@@ -76,6 +76,8 @@ def __init__(
             **kwargs,
         )
         self.request_counter = Counter()
+        self.llm_engine = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS)
 
     def generate(
         self,
@@ -88,9 +90,6 @@ def generate(
         multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
 
-        llm_engine = AsyncLLMEngine.from_engine_args(
-            self.engine_args, usage_context=UsageContext.LLM_CLASS)
-
         if prompts is None:
             raise ValueError("prompts must be provided.")
         if isinstance(prompts, str):
@@ -111,8 +110,8 @@ def generate(
111
110
112
111
async def get_output (prompt , sampling_param ) -> str :
113
112
request_id = random_uuid ()
114
- results_generator = llm_engine .generate (prompt , sampling_param ,
115
- request_id )
113
+ results_generator = self . llm_engine .generate (
114
+ prompt , sampling_param , request_id )
116
115
final_output = None
117
116
async for request_output in results_generator :
118
117
final_output = request_output
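Taken together, the hunks above move engine construction out of generate() and into __init__: the AsyncLLMEngine is built once when the wrapper is created and then reused for every request. A condensed sketch of the resulting pattern (the class name and the event-loop driving code are illustrative additions, not lines from this diff):

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid


class AsyncLLMSketch:
    """Condensed, hypothetical view of the async test wrapper after this change."""

    def __init__(self, model: str, **kwargs) -> None:
        engine_args = AsyncEngineArgs(model=model, **kwargs)
        # Build the engine once, at construction time ...
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

    def generate(self, prompt: str, sampling_params):
        async def get_output():
            # ... and reuse it for every request, instead of rebuilding it
            # inside each generate() call as the removed lines did.
            results_generator = self.llm_engine.generate(
                prompt, sampling_params, random_uuid())
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            return final_output

        return asyncio.get_event_loop().run_until_complete(get_output())
```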
@@ -185,12 +184,25 @@ def generator_outer():
     return generator_outer
 
 
+def maybe_assert_ngram_worker(llm):
+    # Verify the proposer worker is ngram if ngram is specified.
+    if (not isinstance(llm, AsyncLLM)
+            and llm.llm_engine.speculative_config is not None
+            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+        from vllm.spec_decode.ngram_worker import NGramWorker
+        assert isinstance(
+            llm.llm_engine.model_executor.driver_worker.proposer_worker,
+            NGramWorker)
+
+
 def get_output_from_llm_generator(
         llm_generator, prompts,
         sampling_params) -> Tuple[List[str], List[List[int]]]:
     tokens = []
     token_ids = []
     for llm in llm_generator():
+        maybe_assert_ngram_worker(llm)
+
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         token_ids = [output.outputs[0].token_ids for output in outputs]
         tokens = [output.outputs[0].text for output in outputs]
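The new maybe_assert_ngram_worker helper only checks the synchronous LLM path: when a speculative config with ngram prompt lookup is present, it asserts that the proposer worker driving speculation really is an NGramWorker. A hedged sketch of a configuration that would exercise the assertion; the model name and speculative-decoding kwargs below are assumptions about the test setup, not part of this diff:

```python
from vllm import LLM, SamplingParams

# Assumed ngram prompt-lookup speculative-decoding configuration; the exact
# kwargs and model are illustrative, not taken from this diff.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    use_v2_block_manager=True,
)

# speculative_config.ngram_prompt_lookup_max > 0 here, so the helper asserts
# that the driver worker's proposer_worker is an NGramWorker.
maybe_assert_ngram_worker(llm)

outputs = llm.generate(["The future of AI is"], SamplingParams(max_tokens=16))
```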