openai_server chat_api add force_answer_prefix_token_ids, output_log_prob_token_id, TopLogProb add eos
SomeoneKong committed Aug 3, 2024
1 parent 46544eb commit 837e164
Showing 2 changed files with 41 additions and 5 deletions.
13 changes: 12 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -212,7 +212,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))

force_answer_prefix_token_ids: Optional[List[int]] = Field(
default=None,
description=(
"If specified, the output will be prefixed with the given tokens"),
)
output_log_prob_token_id: Optional[bool] = Field(
default=False,
description=(
"If specified, the output log_prob will add token_id"),
)
# doc: end-chat-completion-extra-params

def to_sampling_params(
@@ -587,6 +596,8 @@ class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None
token_id: Optional[int] = None
eos: Optional[bool] = False


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
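The two fields above are extra parameters on the existing /v1/chat/completions endpoint, so a client exercises them simply by adding them to the request body. A minimal sketch of such a request, assuming a server at localhost:8000 serving a model named "my-model"; the prefix token ids are placeholders and would normally come from the serving model's tokenizer:

    # Sketch only: server URL, model name and token ids below are assumptions.
    import json
    import urllib.request

    payload = {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Answer with yes or no."}],
        "logprobs": True,
        "top_logprobs": 2,
        # extra params introduced by this commit
        "force_answer_prefix_token_ids": [9891, 29892],   # placeholder ids
        "output_log_prob_token_id": True,
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(resp.read().decode("utf-8"))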
33 changes: 29 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -149,6 +149,10 @@ async def create_chat_completion(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

if request.force_answer_prefix_token_ids:
prompt_inputs["prompt_token_ids"] += request.force_answer_prefix_token_ids
prompt_inputs["prompt"] += tokenizer.decode(request.force_answer_prefix_token_ids)

engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
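
For context on the hunk above: the optional prefix token ids are appended to the already-tokenized prompt, and their decoded text is appended to the prompt string, so generation continues as if the answer already started with those tokens. A standalone toy illustration of that append step (not vLLM code; the tokenizer here is a stub):

    # Toy stand-in tokenizer; real ids and decodings come from the served model.
    class StubTokenizer:
        _vocab = {9891: "Yes", 29892: ","}
        def decode(self, ids):
            return "".join(self._vocab.get(i, "<unk>") for i in ids)

    tokenizer = StubTokenizer()
    prompt_inputs = {
        "prompt": "<user> Is it raining? <assistant>",
        "prompt_token_ids": [101, 102, 103],          # placeholder prompt ids
    }
    force_answer_prefix_token_ids = [9891, 29892]

    # Same shape as the change above: extend the ids and keep the text in sync.
    prompt_inputs["prompt_token_ids"] += force_answer_prefix_token_ids
    prompt_inputs["prompt"] += tokenizer.decode(force_answer_prefix_token_ids)
    print(prompt_inputs["prompt"])            # ... <assistant>Yes,
    print(prompt_inputs["prompt_token_ids"])  # [101, 102, 103, 9891, 29892]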
@@ -176,15 +180,20 @@ async def create_chat_completion(
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

sampling_params = sampling_params.clone()
sampling_params.update_from_generation_config(
self.engine.engine.generation_config_fields)
stop_token_ids = sampling_params.stop_token_ids

# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
request, result_generator, request_id, conversation, tokenizer, stop_token_ids)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation, tokenizer)
conversation, tokenizer, stop_token_ids)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
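
The hunk above clones the request's sampling params, folds in the model's generation_config defaults, and then reads off stop_token_ids so the downstream logprob helpers can flag EOS candidates. A rough stand-in sketch of that derivation (FakeSamplingParams is a placeholder, not vLLM's SamplingParams):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class FakeSamplingParams:          # placeholder for vllm's SamplingParams
        stop_token_ids: List[int] = field(default_factory=list)

        def clone(self) -> "FakeSamplingParams":
            return FakeSamplingParams(stop_token_ids=list(self.stop_token_ids))

        def update_from_generation_config(self, generation_config_fields: dict) -> None:
            eos = generation_config_fields.get("eos_token_id")
            if eos is not None and eos not in self.stop_token_ids:
                self.stop_token_ids.append(eos)

    params = FakeSamplingParams(stop_token_ids=[32000]).clone()
    params.update_from_generation_config({"eos_token_id": 2})   # made-up config
    stop_token_ids = params.stop_token_ids
    print(stop_token_ids)   # [32000, 2]; threaded into the logprob helpers below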
@@ -202,6 +211,7 @@ async def chat_completion_stream_generator(
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
stop_token_ids: List[int]
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@@ -310,6 +320,8 @@ async def chat_completion_stream_generator(
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
add_token_id=request.output_log_prob_token_id or False,
stop_token_ids=stop_token_ids,
)
else:
logprobs = None
@@ -427,6 +439,7 @@ async def chat_completion_full_generator(
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
stop_token_ids: List[int]
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.served_model_names[0]
@@ -455,6 +468,8 @@ async def chat_completion_full_generator(
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
tokenizer=tokenizer,
add_token_id=request.output_log_prob_token_id or False,
stop_token_ids=stop_token_ids,
)
else:
logprobs = None
@@ -510,13 +525,18 @@ async def chat_completion_full_generator(

def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
tokenizer: PreTrainedTokenizer,
add_token_id: bool = False,
stop_token_ids: List[int] = None
) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
p[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)),
token_id=p[0] if add_token_id else None,
eos=p[0] in stop_token_ids if add_token_id else None,
logprob=max(p[1].logprob, -9999.0),
bytes=list(
token.encode("utf-8", errors="replace")))
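
The new eos flag on each top-logprob candidate is just a membership test of the candidate's token id against the stop_token_ids gathered above, and token_id is only emitted when the request set output_log_prob_token_id. A self-contained sketch of that per-candidate logic with made-up ids:

    # Made-up ids: 2 stands in for an EOS/stop token, 5 for an ordinary token.
    stop_token_ids = [2, 32000]
    add_token_id = True          # mirrors request.output_log_prob_token_id

    candidates = [(5, -0.12), (2, -2.34)]   # (token_id, logprob) pairs, illustrative
    entries = [
        {
            "token_id": tid if add_token_id else None,
            "eos": (tid in stop_token_ids) if add_token_id else None,
            "logprob": max(lp, -9999.0),
        }
        for tid, lp in candidates
    ]
    print(entries)   # second entry has eos=True since 2 is in stop_token_ids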
@@ -530,20 +550,24 @@ def _create_chat_logprobs(
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
add_token_id: bool = False,
stop_token_ids: List[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""

logprobs_content = []

for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
output_token_id = token_id if add_token_id else None
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
token_id=output_token_id,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
@@ -553,11 +577,12 @@ def _create_chat_logprobs(
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
token_id=output_token_id,
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs,
tokenizer)))
tokenizer, add_token_id, stop_token_ids)))

return ChatCompletionLogProbs(content=logprobs_content)
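
End to end, when output_log_prob_token_id is set the serialized logprobs carry the extra fields added to the protocol above: each content entry gains token_id, and each top-logprob candidate additionally gains eos. A hedged sketch of walking such a response payload (the literal below is hand-written for illustration, not captured server output):

    response = {   # illustrative shape only
        "choices": [{
            "logprobs": {"content": [{
                "token": "Yes",
                "token_id": 9891,
                "logprob": -0.02,
                "bytes": [89, 101, 115],
                "top_logprobs": [
                    {"token": "Yes", "token_id": 9891, "eos": False, "logprob": -0.02},
                    {"token": "</s>", "token_id": 2, "eos": True, "logprob": -4.10},
                ],
            }]},
        }],
    }

    for entry in response["choices"][0]["logprobs"]["content"]:
        print(entry["token"], "->", entry["token_id"])
        for cand in entry["top_logprobs"]:
            print("  candidate", cand["token_id"], "eos:", cand["eos"])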
