Skip to content

Commit 333a0e6

Browse files

Authored commit: Add function calling to briton template (#1107)

1 parent: 7e9a2be — commit: 333a0e6

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

truss/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@
108108
BASE_TRTLLM_REQUIREMENTS = [
109109
"grpcio==1.64.0",
110110
"grpcio-tools==1.64.0",
111-
"transformers==4.43.2",
111+
"transformers==4.44.2",
112112
"truss==0.9.30rc3",
113113
"outlines==0.0.46",
114114
"torch==2.4.0",
115+
"sentencepiece==0.2.0",
115116
]
116117
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
117118
"--extra-index-url https://pypi.nvidia.com",

truss/templates/trtllm-briton/src/engine.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,19 @@ async def predict(self, model_input):
189189
channel = grpc.aio.insecure_channel(f"localhost:{BRITON_PORT}")
190190
self._stub = briton_pb2_grpc.BritonStub(channel)
191191

192+
function_calling_schema = None
193+
tools = model_input.get("tools", None)
194+
if tools is not None:
195+
function_calling_schema = {
196+
"anyOf": [create_tool_schema(tool) for tool in tools],
197+
}
198+
192199
prompt = model_input.get("prompt", None)
193200
if prompt is None and "messages" in model_input:
194201
messages = model_input.pop("messages")
195-
prompt = self._tokenizer.apply_chat_template(messages, tokenize=False)
202+
prompt = self._tokenizer.apply_chat_template(
203+
messages, tools=tools, tokenize=False, add_generation_prompt=True
204+
)
196205
if prompt is None or len(prompt) == 0:
197206
raise HTTPException(status_code=400, detail="Prompt cannot be empty.")
198207

@@ -214,8 +223,12 @@ async def predict(self, model_input):
214223
and self._tokenizer.pad_token_id is not None
215224
):
216225
request.pad_id = self._tokenizer.pad_token_id
217-
# Add output schema hash if response_format is provided
218-
schema_hash = self._fsm_cache.add_schema_from_input(model_input)
226+
# Add output schema hash if we're function calling or response_format is provided
227+
schema_hash = (
228+
self._fsm_cache.add_schema(function_calling_schema)
229+
if function_calling_schema is not None
230+
else self._fsm_cache.add_schema_from_input(model_input)
231+
)
219232
if schema_hash is not None:
220233
request.output_schema_hash = schema_hash
221234
set_briton_request_fields_from_model_input(model_input, request)
@@ -248,6 +261,17 @@ async def build_response():
248261
raise HTTPException(status_code=500, detail=f"An error has occurred: {ex}")
249262

250263

264+
def create_tool_schema(tool_json: Dict[str, Any]) -> Dict[str, Any]:
265+
return {
266+
"type": "object",
267+
"properties": {
268+
"name": {"const": tool_json["function"]["name"]},
269+
"parameters": tool_json["function"]["parameters"],
270+
},
271+
"required": ["name", "parameters"],
272+
}
273+
274+
251275
class FsmCache:
252276
def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer):
253277
self._cache_dir = cache_dir
@@ -256,19 +280,23 @@ def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer):
256280
self._cache = set(f.name for f in self._cache_dir.iterdir() if f.is_file())
257281
self._tokenizer = tokenizer
258282

259-
def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]:
260-
schema = self._extract_schema(model_input)
261-
if schema is None:
262-
return None
283+
def add_schema(self, schema: Dict[str, Any]) -> str:
263284
schema_str = json.dumps(schema)
264285
schema_hash = hashlib.sha256(schema_str.encode()).hexdigest()
265286
if schema_hash not in self._cache:
266-
fsm = self._create_fsm(schema_str)
287+
fsm = self._create_fsm(schema)
267288
(self._cache_dir / schema_hash).write_bytes(fsm.SerializeToString())
268289
self._cache.add(schema_hash)
269290
return schema_hash
270291

271-
def _create_fsm(self, schema: str) -> briton_pb2.StatesToTokens: # type: ignore[name-defined]
292+
def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]:
293+
schema_hash = None
294+
schema = self._extract_schema(model_input)
295+
if schema is not None:
296+
schema_hash = self.add_schema(schema)
297+
return schema_hash
298+
299+
def _create_fsm(self, schema: Dict[str, Any]) -> briton_pb2.StatesToTokens: # type: ignore[name-defined]
272300
outlines_tokenizer = TransformerTokenizer(self._tokenizer)
273301
logits_processor = JSONLogitsProcessor(schema, outlines_tokenizer)
274302
guide = logits_processor.fsm

0 commit comments

Comments (0)