@@ -189,10 +189,19 @@ async def predict(self, model_input):
             channel = grpc.aio.insecure_channel(f"localhost:{BRITON_PORT}")
             self._stub = briton_pb2_grpc.BritonStub(channel)

+        function_calling_schema = None
+        tools = model_input.get("tools", None)
+        if tools is not None:
+            function_calling_schema = {
+                "anyOf": [create_tool_schema(tool) for tool in tools],
+            }
+
         prompt = model_input.get("prompt", None)
         if prompt is None and "messages" in model_input:
             messages = model_input.pop("messages")
-            prompt = self._tokenizer.apply_chat_template(messages, tokenize=False)
+            prompt = self._tokenizer.apply_chat_template(
+                messages, tools=tools, tokenize=False, add_generation_prompt=True
+            )
         if prompt is None or len(prompt) == 0:
             raise HTTPException(status_code=400, detail="Prompt cannot be empty.")

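For orientation, a minimal sketch of the request shape this hunk handles: an OpenAI-style `tools` list alongside `messages`, rendered into a prompt through the tokenizer's chat template. The checkpoint name and the `get_weather` tool are only illustrative assumptions; any tokenizer whose chat template accepts the `tools` argument (available in recent `transformers` releases) would behave similarly.

```python
# Illustrative only: a tools-bearing request like the one predict() now accepts.
from transformers import AutoTokenizer

model_input = {
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

# Any chat model whose template understands `tools` works; Qwen2.5 is just an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
prompt = tokenizer.apply_chat_template(
    model_input["messages"],
    tools=model_input["tools"],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # the tool definitions are embedded in the rendered prompt
```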
@@ -214,8 +223,12 @@ async def predict(self, model_input):
             and self._tokenizer.pad_token_id is not None
         ):
             request.pad_id = self._tokenizer.pad_token_id
-        # Add output schema hash if response_format is provided
-        schema_hash = self._fsm_cache.add_schema_from_input(model_input)
+        # Add output schema hash if we're function calling or response_format is provided
+        schema_hash = (
+            self._fsm_cache.add_schema(function_calling_schema)
+            if function_calling_schema is not None
+            else self._fsm_cache.add_schema_from_input(model_input)
+        )
         if schema_hash is not None:
             request.output_schema_hash = schema_hash
         set_briton_request_fields_from_model_input(model_input, request)
@@ -248,6 +261,17 @@ async def build_response():
             raise HTTPException(status_code=500, detail=f"An error has occurred: {ex}")


+def create_tool_schema(tool_json: Dict[str, Any]) -> Dict[str, Any]:
+    return {
+        "type": "object",
+        "properties": {
+            "name": {"const": tool_json["function"]["name"]},
+            "parameters": tool_json["function"]["parameters"],
+        },
+        "required": ["name", "parameters"],
+    }
+
+
 class FsmCache:
     def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer):
         self._cache_dir = cache_dir
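A minimal sketch of what `create_tool_schema` produces for a hypothetical `get_weather` tool, and of the `anyOf` wrapper assembled in `predict()`: the combined schema only admits a `{"name", "parameters"}` object naming one of the declared tools. The `jsonschema` check is for illustration only; in the diff the schema is enforced at generation time via the FSM cache rather than post hoc.

```python
# Illustrative only: the per-tool schema and the anyOf wrapper built from it.
import json

import jsonschema  # used here just to demonstrate what the schema admits

tool_json = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# What create_tool_schema(tool_json) returns:
tool_schema = {
    "type": "object",
    "properties": {
        "name": {"const": tool_json["function"]["name"]},
        "parameters": tool_json["function"]["parameters"],
    },
    "required": ["name", "parameters"],
}

# The wrapper built in predict(): one branch per declared tool.
function_calling_schema = {"anyOf": [tool_schema]}

# A well-formed tool call validates; an undeclared tool name or missing
# parameters would raise jsonschema.ValidationError.
jsonschema.validate(
    {"name": "get_weather", "parameters": {"city": "Paris"}}, function_calling_schema
)
print(json.dumps(function_calling_schema, indent=2))
```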
@@ -256,19 +280,23 @@ def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer):
         self._cache = set(f.name for f in self._cache_dir.iterdir() if f.is_file())
         self._tokenizer = tokenizer

-    def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]:
-        schema = self._extract_schema(model_input)
-        if schema is None:
-            return None
+    def add_schema(self, schema: Dict[str, Any]) -> str:
         schema_str = json.dumps(schema)
         schema_hash = hashlib.sha256(schema_str.encode()).hexdigest()
         if schema_hash not in self._cache:
-            fsm = self._create_fsm(schema_str)
+            fsm = self._create_fsm(schema)
             (self._cache_dir / schema_hash).write_bytes(fsm.SerializeToString())
             self._cache.add(schema_hash)
         return schema_hash

-    def _create_fsm(self, schema: str) -> briton_pb2.StatesToTokens:  # type: ignore[name-defined]
+    def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]:
+        schema_hash = None
+        schema = self._extract_schema(model_input)
+        if schema is not None:
+            schema_hash = self.add_schema(schema)
+        return schema_hash
+
+    def _create_fsm(self, schema: Dict[str, Any]) -> briton_pb2.StatesToTokens:  # type: ignore[name-defined]
         outlines_tokenizer = TransformerTokenizer(self._tokenizer)
         logits_processor = JSONLogitsProcessor(schema, outlines_tokenizer)
         guide = logits_processor.fsm
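A short sketch of the cache-key scheme visible in `FsmCache.add_schema`: the serialized FSM for a schema is stored under the SHA-256 hex digest of its `json.dumps` form, so repeated requests with the same schema (for example, the same tool set) reuse the same cached file. The schema value below is a made-up example; no Briton or outlines objects are needed to show the keying.

```python
# Illustrative only: how the FSM cache key is derived from a JSON schema.
import hashlib
import json

schema = {
    "anyOf": [
        {
            "type": "object",
            "properties": {"name": {"const": "get_weather"}},
            "required": ["name"],
        }
    ]
}

schema_str = json.dumps(schema)
schema_hash = hashlib.sha256(schema_str.encode()).hexdigest()
print(schema_hash)  # stable across calls, so the serialized FSM file is reused

# Note: json.dumps is key-order sensitive, so two semantically equal schemas
# written with different key order would hash (and cache) separately.
```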