add usage statistics for inference API #894

Closed · wants to merge 2 commits

11 changes: 11 additions & 0 deletions llama_stack/apis/inference/inference.py
@@ -186,6 +186,13 @@ class GrammarResponseFormat(BaseModel):
)


@json_schema_type
class UsageStatistics(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int


@json_schema_type
class CompletionRequest(BaseModel):
model: str
@@ -204,6 +211,7 @@ class CompletionResponse(BaseModel):
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
usage: Optional[UsageStatistics] = None


@json_schema_type
@@ -213,6 +221,7 @@ class CompletionResponseStreamChunk(BaseModel):
delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[List[TokenLogProbs]] = None
usage: Optional[UsageStatistics] = None


@json_schema_type
@@ -252,6 +261,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
"""SSE-stream of these events."""

event: ChatCompletionResponseEvent
usage: Optional[UsageStatistics] = None


@json_schema_type
@@ -260,6 +270,7 @@ class ChatCompletionResponse(BaseModel):

completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
usage: Optional[UsageStatistics] = None


@json_schema_type
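
For reference, a minimal sketch of how a provider could populate the new usage field using only the types declared above; the imports follow what the later files in this diff use, and the token counts and response text are invented for illustration.

from llama_models.llama3.api.datatypes import StopReason

from llama_stack.apis.inference import CompletionResponse, UsageStatistics

# Illustrative counts; a real provider would take these from its tokenizer.
usage = UsageStatistics(prompt_tokens=12, completion_tokens=34, total_tokens=46)

response = CompletionResponse(
    content="The capital of France is Paris.",
    stop_reason=StopReason.end_of_turn,
    usage=usage,
)
assert response.usage.total_tokens == 46
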
(second changed file in this pull request)
@@ -37,7 +37,6 @@
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model

from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel

@@ -47,7 +46,6 @@
ResponseFormat,
ResponseFormatType,
)

from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
@@ -78,6 +76,7 @@ class TokenResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
input_token_count: Optional[int] = None


class Llama:
@@ -348,6 +347,7 @@ def generate(
if logprobs
else None
),
input_token_count=len(model_input.tokens),
)

prev_pos = cur_pos
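
As a rough illustration of what the generator now yields per step, each TokenResult pairs one sampled token with the prompt length taken from model_input.tokens; the concrete values below are invented, and the snippet assumes it runs inside this module where TokenResult is defined.

# Hypothetical single step from Llama.generate(): one sampled token id, its
# decoded text, no logprobs, and the prompt length used for usage tracking.
step = TokenResult(token=9906, text="Hello", logprobs=None, input_token_count=42)
assert step.input_token_count == 42
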
(third changed file in this pull request)
@@ -38,6 +38,7 @@
ResponseFormat,
TokenLogProbs,
ToolChoice,
UsageStatistics,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -168,8 +169,14 @@ async def completion(
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
def impl():
stop_reason = None
input_token_count = 0
output_token_count = 0
usage_statistics = None

for token_result in self.generator.completion(request):
if input_token_count == 0:
input_token_count = token_result.input_token_count
output_token_count += 1  # each TokenResult carries a single sampled token
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
@@ -191,17 +198,29 @@ def impl():
}
)
]
else:
usage_statistics = UsageStatistics(
prompt_tokens=input_token_count,
completion_tokens=output_token_count,
total_tokens=input_token_count + output_token_count,
)

yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
usage=usage_statistics,
)

if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
usage=UsageStatistics(
prompt_tokens=input_token_count,
completion_tokens=output_token_count,
total_tokens=input_token_count + output_token_count,
),
)

if self.config.create_distributed_process_group:
@@ -221,7 +240,10 @@ def impl():
stop_reason = None

tokenizer = self.generator.formatter.tokenizer
input_token_count = 0
for token_result in self.generator.completion(request):
if input_token_count == 0:
input_token_count = token_result.input_token_count
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
@@ -242,7 +264,7 @@ def impl():
if stop_reason is None:
stop_reason = StopReason.out_of_tokens

content = self.generator.formatter.tokenizer.decode(tokens)
content = tokenizer.decode(tokens)
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
elif content.endswith("<|eom_id|>"):
@@ -251,6 +273,11 @@ def impl():
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
usage=UsageStatistics(
prompt_tokens=input_token_count,
completion_tokens=len(tokens),
total_tokens=input_token_count + len(tokens),
),
)

if self.config.create_distributed_process_group:
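
A small consumer-side sketch of the streaming contract above: usage stays None on intermediate chunks and is filled in once a stop reason is emitted. The print_usage helper is hypothetical and not part of this diff; it accepts any async iterator of CompletionResponseStreamChunk objects, such as the one produced by _stream_completion.

async def print_usage(stream):
    # Iterate the stream and report usage from the final chunk only.
    async for chunk in stream:
        if chunk.usage is not None:
            print(
                f"prompt={chunk.usage.prompt_tokens} "
                f"completion={chunk.usage.completion_tokens} "
                f"total={chunk.usage.total_tokens}"
            )
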
39 changes: 34 additions & 5 deletions llama_stack/providers/utils/inference/openai_compat.py
@@ -12,7 +12,6 @@
TopKSamplingStrategy,
TopPSamplingStrategy,
)

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from pydantic import BaseModel
@@ -24,7 +23,6 @@
ToolCallDelta,
ToolCallParseStatus,
)

from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
@@ -35,8 +33,8 @@
CompletionResponseStreamChunk,
Message,
TokenLogProbs,
UsageStatistics,
)

from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)
@@ -63,8 +61,15 @@ class OpenAICompatCompletionChoice(BaseModel):
logprobs: Optional[OpenAICompatLogprobs] = None


class OpenAICompatCompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int


class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice]
usage: Optional[OpenAICompatCompletionUsage] = None


def get_sampling_strategy_options(params: SamplingParams) -> dict:
Expand Down Expand Up @@ -124,28 +129,45 @@ def convert_openai_completion_logprobs(
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]


def get_usage_statistics(
response: OpenAICompatCompletionResponse,
) -> Optional[UsageStatistics]:
if response.usage:
return UsageStatistics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return None


def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
choice = response.choices[0]
usage_statistics = get_usage_statistics(response)

# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
usage=usage_statistics,
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
usage=usage_statistics,
)
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
logprobs=convert_openai_completion_logprobs(choice.logprobs),
usage=usage_statistics,
)


@@ -164,17 +186,21 @@ def process_chat_completion_response(
tool_calls=raw_message.tool_calls,
),
logprobs=None,
usage=get_usage_statistics(response),
)


async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
stop_reason = None
usage_statistics = None

async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
# usage statistics are only available in the final chunk
usage_statistics = get_usage_statistics(chunk)

text = text_from_choice(choice)
if text == "<|eot_id|>":
@@ -200,6 +226,7 @@ async def process_completion_stream_response(
yield CompletionResponseStreamChunk(
delta="",
stop_reason=stop_reason,
usage=usage_statistics,
)


@@ -216,10 +243,11 @@ async def process_chat_completion_stream_response(
buffer = ""
ipython = False
stop_reason = None

usage_statistics = None
async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
usage_statistics = get_usage_statistics(chunk)

if finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
@@ -313,7 +341,8 @@ async def process_chat_completion_stream_response(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
),
usage=usage_statistics,
)


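
To close the loop on the OpenAI-compatible path, a self-contained sketch that runs a fabricated response through get_usage_statistics; the finish_reason and text fields of OpenAICompatCompletionChoice are not visible in this hunk, so their use here is an assumption, and all numbers are invented.

from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    OpenAICompatCompletionUsage,
    get_usage_statistics,
)

choice = OpenAICompatCompletionChoice(finish_reason="stop", text="Hi there")

# A response that reports token usage maps onto UsageStatistics.
with_usage = OpenAICompatCompletionResponse(
    choices=[choice],
    usage=OpenAICompatCompletionUsage(
        prompt_tokens=8, completion_tokens=3, total_tokens=11
    ),
)
stats = get_usage_statistics(with_usage)
assert stats is not None and stats.total_tokens == 11

# A response without a usage block yields None, so the field is simply omitted.
without_usage = OpenAICompatCompletionResponse(choices=[choice])
assert get_usage_statistics(without_usage) is None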