BT-11405 json schema structured output (#1095)
bdubayah authored Aug 26, 2024
1 parent fc8d635 commit aa35e2e
Showing 4 changed files with 120 additions and 21 deletions.
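In short: this commit teaches the TRT-LLM Briton engine to do JSON-schema structured output. A request may carry an OpenAI-style response_format; the schema is compiled into a token-level finite-state machine with outlines, cached on disk keyed by the schema's hash, and referenced per request through a new output_schema_hash protobuf field. A sketch of a request body that the new validation in engine.py would accept (the prompt and schema contents are invented for illustration):

# Hypothetical request payload; only the response_format structure is
# prescribed by _extract_schema below, the values are made up.
model_input = {
    "prompt": "Return a JSON object describing a user.",
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"},
                },
                "required": ["name", "age"],
            }
        },
    },
}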
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.30rc3"
version = "0.9.30rc4"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
6 changes: 4 additions & 2 deletions truss/constants.py
@@ -103,13 +103,15 @@

REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

-TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.5"
+TRTLLM_BASE_IMAGE = "baseten/briton-server:5fa9436e_v0.0.8"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
BASE_TRTLLM_REQUIREMENTS = [
    "grpcio==1.64.0",
    "grpcio-tools==1.64.0",
    "transformers==4.43.2",
-    "truss==0.9.27rc2",
+    "truss==0.9.30rc1",
+    "outlines==0.0.46",
+    "torch==2.4.0",
]
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
    "--extra-index-url https://pypi.nvidia.com",
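The new outlines==0.0.46 pin supplies the JSONLogitsProcessor that engine.py (below) uses to compile a JSON schema into a finite-state machine over token ids. A minimal standalone sketch of that compilation step; the tokenizer choice is an assumption, while the .fsm and .states_to_token_maps attributes mirror their use in _create_fsm below:

# Standalone sketch: compile a JSON schema to a token-level FSM with outlines.
# The tokenizer choice ("gpt2") is a placeholder, not from this commit.
from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import JSONLogitsProcessor
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical tokenizer
outlines_tokenizer = TransformerTokenizer(tokenizer)
processor = JSONLogitsProcessor('{"type": "object"}', outlines_tokenizer)

# processor.fsm maps FSM states to {token_id: next_state}, the structure
# consumed by FsmCache._create_fsm in engine.py below.
guide = processor.fsm
print(f"{len(guide.states_to_token_maps)} FSM states")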
46 changes: 32 additions & 14 deletions truss/templates/trtllm-briton/packages/briton_pb2.py
@@ -4,7 +4,7 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: briton.proto
-# Protobuf Python Version: 5.26.1
+# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""

from google.protobuf import descriptor as _descriptor
@@ -18,22 +18,40 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0c\x62riton.proto\x12\x06\x62riton"\xb0\x04\n\x10InferenceRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x12\n\ninput_text\x18\x02 \x01(\t\x12\x11\n\tinput_ids\x18\x03 \x03(\x05\x12\x1f\n\x12request_output_len\x18\x05 \x01(\rH\x00\x88\x01\x01\x12\x13\n\x06\x65nd_id\x18\x06 \x01(\rH\x01\x88\x01\x01\x12\x13\n\x06pad_id\x18\x07 \x01(\rH\x02\x88\x01\x01\x12\x17\n\nbeam_width\x18\n \x01(\rH\x03\x88\x01\x01\x12\x18\n\x0btemperature\x18\x0b \x01(\x02H\x04\x88\x01\x01\x12\x1a\n\rruntime_top_k\x18\x0c \x01(\rH\x05\x88\x01\x01\x12\x1a\n\rruntime_top_p\x18\r \x01(\x02H\x06\x88\x01\x01\x12\x18\n\x0blen_penalty\x18\x0e \x01(\x02H\x07\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x0f \x01(\x02H\x08\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x10 \x01(\x02H\t\x88\x01\x01\x12\x11\n\tbad_words\x18\x11 \x03(\t\x12\x12\n\nstop_words\x18\x12 \x03(\tB\x15\n\x13_request_output_lenB\t\n\x07_end_idB\t\n\x07_pad_idB\r\n\x0b_beam_widthB\x0e\n\x0c_temperatureB\x10\n\x0e_runtime_top_kB\x10\n\x0e_runtime_top_pB\x0e\n\x0c_len_penaltyB\x15\n\x13_repetition_penaltyB\x13\n\x11_presence_penalty"R\n\x13InferenceAnswerPart\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12\x12\n\noutput_ids\x18\x03 \x03(\x05"\xfa\x03\n\x0c\x42ritonConfig\x12\x13\n\x0b\x65ngine_path\x18\x01 \x01(\t\x12\x14\n\x0chf_tokenizer\x18\x02 \x01(\t\x12N\n\x16\x62\x61tch_scheduler_policy\x18\x05 \x01(\x0e\x32).briton.BritonConfig.BatchSchedulerPolicyH\x00\x88\x01\x01\x12\x1f\n\x12\x65nable_trt_overlap\x18\x06 \x01(\x08H\x01\x88\x01\x01\x12)\n\x1cmax_tokens_in_paged_kv_cache\x18\n \x01(\x04H\x02\x88\x01\x01\x12+\n\x1ekv_cache_free_gpu_mem_fraction\x18\x0b \x01(\x02H\x03\x88\x01\x01\x12!\n\x14medusa_decoding_mode\x18\x0c \x01(\x08H\x04\x88\x01\x01"D\n\x14\x42\x61tchSchedulerPolicy\x12\x13\n\x0fMAX_UTILIZATION\x10\x00\x12\x17\n\x13GUARANTEED_NO_EVICT\x10\x01\x42\x19\n\x17_batch_scheduler_policyB\x15\n\x13_enable_trt_overlapB\x1f\n\x1d_max_tokens_in_paged_kv_cacheB!\n\x1f_kv_cache_free_gpu_mem_fractionB\x17\n\x15_medusa_decoding_mode2L\n\x06\x42riton\x12\x42\n\x05Infer\x12\x18.briton.InferenceRequest\x1a\x1b.briton.InferenceAnswerPart"\x00\x30\x01\x62\x06proto3'
b'\n\x0c\x62riton.proto\x12\x06\x62riton"r\n\x06Tensor\x12#\n\x05shape\x18\x01 \x01(\x0b\x32\x14.briton.Tensor.Shape\x12\x1f\n\x05\x64type\x18\x02 \x01(\x0e\x32\x10.briton.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x1a\x14\n\x05Shape\x12\x0b\n\x03\x64im\x18\x01 \x03(\x03"\xb4\x06\n\x10InferenceRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x12\n\ninput_text\x18\x02 \x01(\t\x12\x11\n\tinput_ids\x18\x03 \x03(\x05\x12\x1f\n\x12request_output_len\x18\x05 \x01(\rH\x00\x88\x01\x01\x12\x13\n\x06\x65nd_id\x18\x06 \x01(\rH\x01\x88\x01\x01\x12\x13\n\x06pad_id\x18\x07 \x01(\rH\x02\x88\x01\x01\x12\x17\n\nbeam_width\x18\n \x01(\rH\x03\x88\x01\x01\x12\x18\n\x0btemperature\x18\x0b \x01(\x02H\x04\x88\x01\x01\x12\x1a\n\rruntime_top_k\x18\x0c \x01(\rH\x05\x88\x01\x01\x12\x1a\n\rruntime_top_p\x18\r \x01(\x02H\x06\x88\x01\x01\x12\x18\n\x0blen_penalty\x18\x0e \x01(\x02H\x07\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x0f \x01(\x02H\x08\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x10 \x01(\x02H\t\x88\x01\x01\x12\x11\n\tbad_words\x18\x11 \x03(\t\x12\x12\n\nstop_words\x18\x12 \x03(\t\x12\x19\n\x0clora_task_id\x18\x13 \x01(\x04H\n\x88\x01\x01\x12)\n\x0clora_weights\x18\x14 \x01(\x0b\x32\x0e.briton.TensorH\x0b\x88\x01\x01\x12(\n\x0blora_config\x18\x15 \x01(\x0b\x32\x0e.briton.TensorH\x0c\x88\x01\x01\x12\x18\n\x0brandom_seed\x18\x16 \x01(\x03H\r\x88\x01\x01\x12\x1f\n\x12output_schema_hash\x18\x17 \x01(\tH\x0e\x88\x01\x01\x42\x15\n\x13_request_output_lenB\t\n\x07_end_idB\t\n\x07_pad_idB\r\n\x0b_beam_widthB\x0e\n\x0c_temperatureB\x10\n\x0e_runtime_top_kB\x10\n\x0e_runtime_top_pB\x0e\n\x0c_len_penaltyB\x15\n\x13_repetition_penaltyB\x13\n\x11_presence_penaltyB\x0f\n\r_lora_task_idB\x0f\n\r_lora_weightsB\x0e\n\x0c_lora_configB\x0e\n\x0c_random_seedB\x15\n\x13_output_schema_hash"R\n\x13InferenceAnswerPart\x12\x12\n\nrequest_id\x18\x01 \x01(\x03\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12\x12\n\noutput_ids\x18\x03 \x03(\x05"\xa6\x08\n\x0c\x42ritonConfig\x12\x13\n\x0b\x65ngine_path\x18\x01 \x01(\t\x12\x14\n\x0chf_tokenizer\x18\x02 \x01(\t\x12N\n\x16\x62\x61tch_scheduler_policy\x18\x05 \x01(\x0e\x32).briton.BritonConfig.BatchSchedulerPolicyH\x00\x88\x01\x01\x12\x1f\n\x12\x65nable_trt_overlap\x18\x06 \x01(\x08H\x01\x88\x01\x01\x12)\n\x1cmax_tokens_in_paged_kv_cache\x18\n \x01(\x04H\x02\x88\x01\x01\x12+\n\x1ekv_cache_free_gpu_mem_fraction\x18\x0b \x01(\x02H\x03\x88\x01\x01\x12!\n\x14medusa_decoding_mode\x18\x0c \x01(\x08H\x04\x88\x01\x01\x12#\n\x16\x65nable_chunked_context\x18\r \x01(\x08H\x05\x88\x01\x01\x12"\n\x15\x65nable_kv_cache_reuse\x18\x0e \x01(\x08H\x06\x88\x01\x01\x12\'\n\x1akv_cache_host_memory_bytes\x18\x0f \x01(\x04H\x07\x88\x01\x01\x12(\n\x1blora_cache_max_adapter_size\x18\x10 \x01(\x04H\x08\x88\x01\x01\x12,\n\x1flora_cache_optimal_adapter_size\x18\x11 \x01(\x04H\t\x88\x01\x01\x12+\n\x1elora_cache_gpu_memory_fraction\x18\x12 \x01(\x02H\n\x88\x01\x01\x12)\n\x1clora_cache_host_memory_bytes\x18\x13 \x01(\x04H\x0b\x88\x01\x01\x12\x1a\n\rfsm_cache_dir\x18\x14 \x01(\tH\x0c\x88\x01\x01"D\n\x14\x42\x61tchSchedulerPolicy\x12\x13\n\x0fMAX_UTILIZATION\x10\x00\x12\x17\n\x13GUARANTEED_NO_EVICT\x10\x01\x42\x19\n\x17_batch_scheduler_policyB\x15\n\x13_enable_trt_overlapB\x1f\n\x1d_max_tokens_in_paged_kv_cacheB!\n\x1f_kv_cache_free_gpu_mem_fractionB\x17\n\x15_medusa_decoding_modeB\x19\n\x17_enable_chunked_contextB\x18\n\x16_enable_kv_cache_reuseB\x1d\n\x1b_kv_cache_host_memory_bytesB\x1e\n\x1c_lora_cache_max_adapter_sizeB"\n _lora_cache_optimal_adapter_sizeB!\n\x1f_lora_cache_gpu_memory_fractionB\x1f\n\x1d_lora_cache_host_memory_bytesB\x10\n\x0e_fsm_cache_dir"\x98\x01\n\x10TokenToNextState\x12K\n\x13token_to_next_state\x18\x01 \x03(\x0b\x32..briton.TokenToNextState.TokenToNextStateEntry\x1a\x37\n\x15TokenToNextStateEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01"\xfb\x01\n\x0eStatesToTokens\x12\x44\n\x10states_to_tokens\x18\x01 \x03(\x0b\x32*.briton.StatesToTokens.StatesToTokensEntry\x12\x17\n\nvocab_size\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12\x19\n\x0c\x65os_token_id\x18\x03 \x01(\x05H\x01\x88\x01\x01\x1aO\n\x13StatesToTokensEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.briton.TokenToNextState:\x02\x38\x01\x42\r\n\x0b_vocab_sizeB\x0f\n\r_eos_token_id*\xa8\x01\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0b\n\x07\x44T_INT4\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_UINT8\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0b\x12\x0e\n\nDT_FLOAT32\x10\x0c\x12\n\n\x06\x44T_FP8\x10\r\x12\x0b\n\x07\x44T_BOOL\x10\x14\x32L\n\x06\x42riton\x12\x42\n\x05Infer\x12\x18.briton.InferenceRequest\x1a\x1b.briton.InferenceAnswerPart"\x00\x30\x01\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "briton_pb2", _globals)
-if not _descriptor._USE_C_DESCRIPTORS:
-    DESCRIPTOR._loaded_options = None
-    _globals["_INFERENCEREQUEST"]._serialized_start = 25
-    _globals["_INFERENCEREQUEST"]._serialized_end = 585
-    _globals["_INFERENCEANSWERPART"]._serialized_start = 587
-    _globals["_INFERENCEANSWERPART"]._serialized_end = 669
-    _globals["_BRITONCONFIG"]._serialized_start = 672
-    _globals["_BRITONCONFIG"]._serialized_end = 1178
-    _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_start = 967
-    _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_end = 1035
-    _globals["_BRITON"]._serialized_start = 1180
-    _globals["_BRITON"]._serialized_end = 1256
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._options = None
+    _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_options = b"8\001"
+    _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._options = None
+    _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_options = b"8\001"
+    _globals["_DATATYPE"]._serialized_start = 2522
+    _globals["_DATATYPE"]._serialized_end = 2690
+    _globals["_TENSOR"]._serialized_start = 24
+    _globals["_TENSOR"]._serialized_end = 138
+    _globals["_TENSOR_SHAPE"]._serialized_start = 118
+    _globals["_TENSOR_SHAPE"]._serialized_end = 138
+    _globals["_INFERENCEREQUEST"]._serialized_start = 141
+    _globals["_INFERENCEREQUEST"]._serialized_end = 961
+    _globals["_INFERENCEANSWERPART"]._serialized_start = 963
+    _globals["_INFERENCEANSWERPART"]._serialized_end = 1045
+    _globals["_BRITONCONFIG"]._serialized_start = 1048
+    _globals["_BRITONCONFIG"]._serialized_end = 2110
+    _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_start = 1661
+    _globals["_BRITONCONFIG_BATCHSCHEDULERPOLICY"]._serialized_end = 1729
+    _globals["_TOKENTONEXTSTATE"]._serialized_start = 2113
+    _globals["_TOKENTONEXTSTATE"]._serialized_end = 2265
+    _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_start = 2210
+    _globals["_TOKENTONEXTSTATE_TOKENTONEXTSTATEENTRY"]._serialized_end = 2265
+    _globals["_STATESTOTOKENS"]._serialized_start = 2268
+    _globals["_STATESTOTOKENS"]._serialized_end = 2519
+    _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_start = 2408
+    _globals["_STATESTOTOKENS_STATESTOTOKENSENTRY"]._serialized_end = 2487
+    _globals["_BRITON"]._serialized_start = 2692
+    _globals["_BRITON"]._serialized_end = 2768
# @@protoc_insertion_point(module_scope)
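The regenerated briton_pb2.py adds Tensor, TokenToNextState, and StatesToTokens messages, an fsm_cache_dir config field, and an output_schema_hash request field. A hedged sketch of how the new FSM messages compose, using dummy state and token ids:

# Sketch: a two-state FSM packed into the new messages and serialized the
# same way FsmCache persists it; ids, vocab size, and EOS id are dummies.
import briton_pb2

states_to_tokens = briton_pb2.StatesToTokens(
    states_to_tokens={
        0: briton_pb2.TokenToNextState(token_to_next_state={123: 1}),
        1: briton_pb2.TokenToNextState(token_to_next_state={125: 2}),
    },
    vocab_size=32000,
    eos_token_id=2,
)
payload = states_to_tokens.SerializeToString()  # bytes written to the FSM cache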
87 changes: 83 additions & 4 deletions truss/templates/trtllm-briton/src/engine.py
@@ -1,15 +1,21 @@
import hashlib
import json
import os
import signal
import socket
import subprocess
import threading
import time
from itertools import count
from pathlib import Path
from typing import Any, Dict, Optional

import briton_pb2
import briton_pb2_grpc
import grpc
from fastapi import HTTPException
from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import JSONLogitsProcessor
from transformers import AutoTokenizer
from truss.config.trt_llm import TrussTRTLLMBuildConfiguration
from truss.constants import OPENAI_COMPATIBLE_TAG
@@ -29,6 +35,9 @@
"runtime_top_p": "runtime_top_p",
}

# Use a directory that can be picked up by baseten-fs
FSM_CACHE_DIR = "/cache/model/fsm_cache"


def is_port_available(port, host="localhost"):
    try:
@@ -98,11 +107,15 @@ def load(self):
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._tokenizer_repository, token=self._hf_token
        )

+        self._fsm_cache = FsmCache(Path(FSM_CACHE_DIR), self._tokenizer)

        # Start engine
        config_str = f"""
engine_path: "{self._data_dir.resolve()}"
hf_tokenizer: "{self._tokenizer_repository}"
kv_cache_free_gpu_mem_fraction: {self._kv_cache_free_gpu_mem_fraction}
+fsm_cache_dir: "{FSM_CACHE_DIR}"
"""
        config_pbtxt_path = (self._data_dir / "briton_config.pbtxt").resolve()
        config_pbtxt_path.write_text(config_str)
@@ -176,6 +189,10 @@ async def predict(self, model_input):
        if prompt is None and "messages" in model_input:
            messages = model_input.pop("messages")
            prompt = self._tokenizer.apply_chat_template(messages, tokenize=False)
+        if prompt is None or len(prompt) == 0:
+            raise HTTPException(status_code=400, detail="Prompt cannot be empty.")
+
+        self.validate_input(model_input)

        request_id = int(str(os.getpid()) + str(next(self._request_id_counter)))
        request = briton_pb2.InferenceRequest(
@@ -193,14 +210,16 @@
            and self._tokenizer.pad_token_id is not None
        ):
            request.pad_id = self._tokenizer.pad_token_id
+        # Add output schema hash if response_format is provided
+        schema_hash = self._fsm_cache.add_schema_from_input(model_input)
+        if schema_hash is not None:
+            request.output_schema_hash = schema_hash
        set_briton_request_fields_from_model_input(model_input, request)
        for words in ["bad_words", "stop_words"]:
            if words in model_input:
                for word in model_input[words].split(","):
                    getattr(request, words).append(word)

-        self.validate_input(model_input)
-
        resp_iter = self._stub.Infer(request)

        async def generate():
@@ -222,8 +241,68 @@ async def build_response():
            if ex.code() == grpc.StatusCode.INVALID_ARGUMENT:
                raise HTTPException(status_code=400, detail=ex.details())
        except Exception as ex:
-            print(f"An error has occurred: {ex}")
-            raise ex
+            raise HTTPException(status_code=500, detail=f"An error has occurred: {ex}")


class FsmCache:
    def __init__(self, cache_dir: Path, tokenizer: AutoTokenizer):
        self._cache_dir = cache_dir
        if not self._cache_dir.exists():
            self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._cache = set(f.name for f in self._cache_dir.iterdir() if f.is_file())
        self._tokenizer = tokenizer

    def add_schema_from_input(self, model_input: Dict[str, Any]) -> Optional[str]:
        schema = self._extract_schema(model_input)
        if schema is None:
            return None
        schema_str = json.dumps(schema)
        schema_hash = hashlib.sha256(schema_str.encode()).hexdigest()
        if schema_hash not in self._cache:
            fsm = self._create_fsm(schema_str)
            (self._cache_dir / schema_hash).write_bytes(fsm.SerializeToString())
            self._cache.add(schema_hash)
        return schema_hash

    def _create_fsm(self, schema: str) -> briton_pb2.StatesToTokens:  # type: ignore[name-defined]
        outlines_tokenizer = TransformerTokenizer(self._tokenizer)
        logits_processor = JSONLogitsProcessor(schema, outlines_tokenizer)
        guide = logits_processor.fsm

        states_to_tokens = {}
        for state, token_to_next_state in guide.states_to_token_maps.items():
            states_to_tokens[state] = briton_pb2.TokenToNextState(  # type: ignore[attr-defined]
                token_to_next_state=token_to_next_state
            )
        states_to_tokens_pb = briton_pb2.StatesToTokens(  # type: ignore[attr-defined]
            states_to_tokens=states_to_tokens,
            vocab_size=len(self._tokenizer.vocab),
            eos_token_id=self._tokenizer.eos_token_id,
        )
        return states_to_tokens_pb

    @staticmethod
    def _extract_schema(model_input: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        if "response_format" not in model_input:
            return None
        response_format = model_input["response_format"]
        if "type" not in response_format or response_format["type"] != "json_schema":
            raise HTTPException(
                status_code=400,
                detail='response_format["type"] must be json_schema.',
            )
        if "json_schema" not in response_format:
            raise HTTPException(
                status_code=400,
                detail='response_format["json_schema"] must be provided.',
            )
        json_schema = response_format["json_schema"]
        if "schema" not in json_schema:
            raise HTTPException(
                status_code=400,
                detail='response_format["json_schema"]["schema"] must be provided.',
            )
        return json_schema["schema"]


def set_briton_request_fields_from_model_input(model_input, briton_request):
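Putting the pieces together: the cache key is the SHA-256 hex digest of the JSON-serialized schema, the compiled FSM lives at /cache/model/fsm_cache/<hash> where the Briton server (pointed there via fsm_cache_dir) can load it, and each request names its schema through output_schema_hash. A condensed sketch of that flow, with illustrative values:

# Sketch of the keying scheme shared by FsmCache and predict(); the schema
# and request values are illustrative.
import hashlib
import json
from pathlib import Path

import briton_pb2

FSM_CACHE_DIR = Path("/cache/model/fsm_cache")
schema = {"type": "object"}

schema_str = json.dumps(schema)
schema_hash = hashlib.sha256(schema_str.encode()).hexdigest()
fsm_path = FSM_CACHE_DIR / schema_hash  # FsmCache writes the serialized FSM here

request = briton_pb2.InferenceRequest(request_id=1, input_text="Return JSON.")
request.output_schema_hash = schema_hash  # Briton loads the cached FSM by hash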
