Skip to content

Commit c0fd4e0

Browse files
authored
Raise error if encountered in chat completion SSE stream (#2558)
1 parent 64bcff5 commit c0fd4e0

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

src/huggingface_hub/inference/_common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ def _format_chat_completion_stream_output(
350350
# Decode payload
351351
payload = byte_payload.decode("utf-8")
352352
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
353+
354+
# Either an error is being returned
355+
if json_payload.get("error") is not None:
356+
raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
357+
358+
# Or parse token payload
353359
return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
354360

355361

tests/test_inference_client.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
hf_hub_download,
4848
)
4949
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS
50-
from huggingface_hub.errors import HfHubHTTPError
50+
from huggingface_hub.errors import HfHubHTTPError, ValidationError
5151
from huggingface_hub.inference._client import _open_as_binary
5252
from huggingface_hub.inference._common import (
5353
_stream_chat_completion_response,
@@ -919,7 +919,14 @@ def test_model_and_base_url_mutually_exclusive(self):
919919
InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")
920920

921921

922-
@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
922+
@pytest.mark.parametrize(
923+
"stop_signal",
924+
[
925+
b"data: [DONE]",
926+
b"data: [DONE]\n",
927+
b"data: [DONE] ",
928+
],
929+
)
923930
def test_stream_text_generation_response(stop_signal: bytes):
924931
data = [
925932
b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',
@@ -935,7 +942,14 @@ def test_stream_text_generation_response(stop_signal: bytes):
935942
assert output == [" trying", " to"]
936943

937944

938-
@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
945+
@pytest.mark.parametrize(
946+
"stop_signal",
947+
[
948+
b"data: [DONE]",
949+
b"data: [DONE]\n",
950+
b"data: [DONE] ",
951+
],
952+
)
939953
def test_stream_chat_completion_response(stop_signal: bytes):
940954
data = [
941955
b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
@@ -952,6 +966,20 @@ def test_stream_chat_completion_response(stop_signal: bytes):
952966
assert output[1].choices[0].delta.content == " Rust"
953967

954968

969+
def test_chat_completion_error_in_stream():
970+
"""
971+
Regression test for https://github.com/huggingface/huggingface_hub/issues/2514.
972+
When an error is encountered in the stream, it should raise a TextGenerationError (e.g. a ValidationError).
973+
"""
974+
data = [
975+
b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
976+
b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}',
977+
]
978+
with pytest.raises(ValidationError):
979+
for token in _stream_chat_completion_response(data):
980+
pass
981+
982+
955983
INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
956984
INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud" # example
957985
LOCAL_TGI_URL = "http://0.0.0.0:8080"

0 commit comments

Comments
 (0)