Skip to content

Commit

Permalink
BT-12002 BT-12001 BT-12023 Misc briton fixes (#1125)
Browse files: browse the repository at this point in the history.
* BT-12002; BT-12001; BT-12023

* PR feedback
  • Loading branch information
bdubayah authored Sep 6, 2024
1 parent 5e34879 commit aa6a8e6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.32"
version = "0.9.33rc1"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
50 changes: 42 additions & 8 deletions truss/templates/trtllm-briton/src/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,16 @@ async def predict(self, model_input):
):
request.pad_id = self._tokenizer.pad_token_id
# Add output schema hash if we're function calling or response_format is provided
schema_hash = (
self._fsm_cache.add_schema(function_calling_schema)
if function_calling_schema is not None
else self._fsm_cache.add_schema_from_input(model_input)
)
schema_hash = None
try:
schema_hash = (
self._fsm_cache.add_schema(function_calling_schema)
if function_calling_schema is not None
else self._fsm_cache.add_schema_from_input(model_input)
)
# If the input schema is invalid, we should return a 400
except NotImplementedError as ex:
raise HTTPException(status_code=400, detail=str(ex))
if schema_hash is not None:
request.output_schema_hash = schema_hash
set_briton_request_fields_from_model_input(model_input, request)
Expand All @@ -240,23 +245,52 @@ async def predict(self, model_input):
resp_iter = self._stub.Infer(request)

async def generate():
eos_token = (
self._tokenizer.eos_token
if hasattr(self._tokenizer, "eos_token")
else None
)
async for response in resp_iter:
yield response.output_text
if eos_token:
yield response.output_text.removesuffix(eos_token)
else:
yield response.output_text

async def build_response():
eos_token = (
self._tokenizer.eos_token
if hasattr(self._tokenizer, "eos_token")
else None
)
full_text = ""
async for delta in resp_iter:
full_text += delta.output_text
return full_text
if eos_token:
return full_text.removesuffix(eos_token)
else:
return full_text

try:
if model_input.get("stream", True):
return generate()
gen = generate()
first_chunk = await gen.__anext__()

async def generate_after_first_chunk():
yield first_chunk
async for chunk in gen:
yield chunk

return generate_after_first_chunk()
else:
return await build_response()
except grpc.RpcError as ex:
if ex.code() == grpc.StatusCode.INVALID_ARGUMENT:
raise HTTPException(status_code=400, detail=ex.details())
# If the error is another GRPC exception like NotImplemented, we should return a 500
else:
raise HTTPException(
status_code=500, detail=f"An error has occurred: {ex}"
)
except Exception as ex:
raise HTTPException(status_code=500, detail=f"An error has occurred: {ex}")

Expand Down

0 comments on commit aa6a8e6

Please sign in to comment.