
Commit 77a5273

Support sys_prompt behavior in inference

1 parent: 7558678

23 files changed (+218 lines, -45 lines)


llama_stack/apis/inference/inference.py

Lines changed: 40 additions & 2 deletions
@@ -308,14 +308,46 @@ class CompletionResponseStreamChunk(BaseModel):
     logprobs: Optional[List[TokenLogProbs]] = None


+@json_schema_type
+class SystemMessageBehavior(Enum):
+    """Config for how to override the default system prompt.
+
+    :cvar append: Appends the provided system message to the default system prompt:
+        https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/#-function-definitions-in-the-system-prompt-
+    :cvar replace: Replaces the default system prompt with the provided system message. The system message can include the string
+        '{{function_definitions}}' to indicate where the function definitions should be inserted.
+    """
+
+    append = "append"
+    replace = "replace"
+
+
+@json_schema_type
+class ToolConfig(BaseModel):
+    """Configuration for tool use.
+
+    :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+    :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
+        - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+        - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
+        - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+    """
+
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    system_message_behavior: SystemMessageBehavior = Field(default=SystemMessageBehavior.append)
+
+
 # This is an internally used class
+@json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
     sampling_params: Optional[SamplingParams] = SamplingParams()
+
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
+    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -404,6 +436,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         """Generate a chat completion for the given messages using the specified model.

@@ -412,15 +445,20 @@ async def chat_completion(
         :param sampling_params: Parameters to control the sampling strategy
         :param tools: (Optional) List of tool definitions available to the model
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+            .. deprecated::
+                Use tool_config instead.
         :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
             - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
             - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
             - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+            .. deprecated::
+                Use tool_config instead.
         :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
             - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
             - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
             If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
         """

llama_stack/distribution/routers/routers.py

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -132,12 +133,23 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        if tool_config:
+            if tool_choice != tool_config.tool_choice:
+                raise ValueError("tool_choice and tool_config.tool_choice must match")
+            if tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+        else:
+            tool_config = ToolConfig(
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+            )
         params = dict(
             model_id=model_id,
             messages=messages,
@@ -148,6 +160,7 @@ async def chat_completion(
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
         if stream:
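Usage note (not part of the diff): the router keeps the legacy per-argument call style working by folding tool_choice and tool_prompt_format into a ToolConfig when none is supplied, and by insisting the two styles agree when both are supplied. A standalone sketch of that resolution logic, with an illustrative helper name that does not appear in this commit:

from typing import Optional

from llama_stack.apis.inference import ToolChoice, ToolConfig, ToolPromptFormat  # assumed export path


def resolve_tool_config(
    tool_choice: Optional[ToolChoice],
    tool_prompt_format: Optional[ToolPromptFormat],
    tool_config: Optional[ToolConfig],
) -> ToolConfig:
    """Mirror of the router logic above: prefer an explicit ToolConfig, but
    reject legacy arguments that disagree with it."""
    if tool_config:
        if tool_choice != tool_config.tool_choice:
            raise ValueError("tool_choice and tool_config.tool_choice must match")
        if tool_prompt_format != tool_config.tool_prompt_format:
            raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
        return tool_config
    # No ToolConfig supplied: wrap the legacy arguments so downstream code only sees tool_config.
    return ToolConfig(tool_choice=tool_choice, tool_prompt_format=tool_prompt_format)

Providers then read request.tool_config.tool_choice and request.tool_config.tool_prompt_format instead of the removed top-level request fields, as the generation.py, groq_utils.py, and openai_utils.py hunks below show.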

llama_stack/providers/inline/inference/meta_reference/generation.py

Lines changed: 1 addition & 1 deletion
@@ -400,7 +400,7 @@ def chat_completion(
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 request.messages,
-                request.tool_prompt_format,
+                request.tool_config.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
             temperature=temperature,

llama_stack/providers/inline/inference/meta_reference/inference.py

Lines changed: 3 additions & 2 deletions
@@ -38,6 +38,7 @@
     ResponseFormat,
     TokenLogProbs,
     ToolChoice,
+    ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -252,6 +253,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@@ -262,11 +264,10 @@ async def chat_completion(
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )
         self.check_model(request)

llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,
+    ToolConfig,
 )
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -71,5 +72,6 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")

llama_stack/providers/inline/inference/vllm/vllm.py

Lines changed: 3 additions & 2 deletions
@@ -30,6 +30,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -159,6 +160,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         assert self.engine is not None

@@ -167,10 +169,9 @@
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         log.info("Sampling params: %s", sampling_params)

llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -102,18 +103,18 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

llama_stack/providers/remote/inference/cerebras/cerebras.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -128,6 +129,7 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
@@ -140,6 +142,7 @@ async def chat_completion(
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

llama_stack/providers/remote/inference/databricks/databricks.py

Lines changed: 2 additions & 2 deletions
@@ -89,16 +89,16 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

llama_stack/providers/remote/inference/fireworks/fireworks.py

Lines changed: 3 additions & 2 deletions
@@ -25,6 +25,7 @@
     ResponseFormatType,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -204,18 +205,18 @@ async def chat_completion(
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
+            tool_config=tool_config,
         )

         if stream:

llama_stack/providers/remote/inference/groq/groq.py

Lines changed: 2 additions & 2 deletions
@@ -99,6 +99,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         model_id = self.get_provider_model_id(model_id)
         if model_id == "llama-3.2-3b-preview":
@@ -115,10 +116,9 @@ async def chat_completion(
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             )
         )

llama_stack/providers/remote/inference/groq/groq_utils.py

Lines changed: 2 additions & 2 deletions
@@ -79,7 +79,7 @@ def convert_chat_completion_request(
         # so we exclude it for now
         warnings.warn("repetition_penalty is not supported")

-    if request.tool_prompt_format != ToolPromptFormat.json:
+    if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

     sampling_options = get_sampling_strategy_options(request.sampling_params)
@@ -93,7 +93,7 @@ def convert_chat_completion_request(
         temperature=sampling_options.get("temperature", 1.0),
         top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
-        tool_choice=request.tool_choice.value if request.tool_choice else None,
+        tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
     )


llama_stack/providers/remote/inference/nvidia/nvidia.py

Lines changed: 2 additions & 2 deletions
@@ -171,6 +171,7 @@ async def chat_completion(
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if tool_prompt_format:
             warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
@@ -184,10 +185,9 @@ async def chat_completion(
                 sampling_params=sampling_params,
                 response_format=response_format,
                 tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
                 stream=stream,
                 logprobs=logprobs,
+                tool_config=tool_config,
             ),
             n=1,
         )

llama_stack/providers/remote/inference/nvidia/openai_utils.py

Lines changed: 2 additions & 2 deletions
@@ -282,9 +282,9 @@ async def convert_chat_completion_request(

     if request.tools:
         payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools])
-        if request.tool_choice:
+        if request.tool_config.tool_choice:
             payload.update(
-                tool_choice=request.tool_choice.value
+                tool_choice=request.tool_config.tool_choice.value
             )  # we cannot include tool_choice w/o tools, server will complain

     if request.logprobs:
