     hf_hub_download,
 )
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS
-from huggingface_hub.errors import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError, ValidationError
 from huggingface_hub.inference._client import _open_as_binary
 from huggingface_hub.inference._common import (
     _stream_chat_completion_response,
@@ -919,7 +919,14 @@ def test_model_and_base_url_mutually_exclusive(self):
         InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")


-@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
+@pytest.mark.parametrize(
+    "stop_signal",
+    [
+        b"data: [DONE]",
+        b"data: [DONE]\n",
+        b"data: [DONE] ",
+    ],
+)
 def test_stream_text_generation_response(stop_signal: bytes):
     data = [
         b'data: {"index":1,"token":{"id":4560,"text":" trying","logprob":-2.078125,"special":false},"generated_text":null,"details":null}',
@@ -935,7 +942,14 @@ def test_stream_text_generation_response(stop_signal: bytes):
     assert output == [" trying", " to"]


-@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
+@pytest.mark.parametrize(
+    "stop_signal",
+    [
+        b"data: [DONE]",
+        b"data: [DONE]\n",
+        b"data: [DONE] ",
+    ],
+)
 def test_stream_chat_completion_response(stop_signal: bytes):
     data = [
         b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
@@ -952,6 +966,20 @@ def test_stream_chat_completion_response(stop_signal: bytes):
     assert output[1].choices[0].delta.content == " Rust"


+def test_chat_completion_error_in_stream():
+    """
+    Regression test for https://github.com/huggingface/huggingface_hub/issues/2514.
+    When an error is encountered in the stream, it should raise a TextGenerationError (e.g. a ValidationError).
+    """
+    data = [
+        b'data: {"object":"chat.completion.chunk","id":"","created":1721737661,"model":"","system_fingerprint":"2.1.2-dev0-sha-5fca30e","choices":[{"index":0,"delta":{"role":"assistant","content":"Both"},"logprobs":null,"finish_reason":null}]}',
+        b'data: {"error":"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 6 `inputs` tokens and 4091 `max_new_tokens`","error_type":"validation"}',
+    ]
+    with pytest.raises(ValidationError):
+        for token in _stream_chat_completion_response(data):
+            pass
+
+
 INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
 INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
 LOCAL_TGI_URL = "http://0.0.0.0:8080"
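For context, here is a minimal sketch of the behavior the new regression test pins down. This is illustrative only, not the library's actual parsing code; the helper _parse_stream_payload is hypothetical. The idea is that when a streamed "data:" payload carries an "error" field, the parser surfaces it as an exception (e.g. ValidationError) instead of yielding it as a normal chunk.

# Illustrative sketch only; assumes each streamed line is a b"data: ..." SSE
# payload like the test data above. Not the real huggingface_hub internals.
import json

from huggingface_hub.errors import ValidationError


def _parse_stream_payload(line: bytes) -> dict:
    # Hypothetical helper: strip the SSE "data:" prefix and decode the JSON body.
    payload = json.loads(line.removeprefix(b"data:").strip())
    if "error" in payload:
        # An in-stream error (e.g. error_type "validation") is raised as an
        # exception rather than returned as a chunk to the caller.
        raise ValidationError(payload["error"])
    return payload


# Feeding it the error payload from the test raises ValidationError:
# _parse_stream_payload(b'data: {"error":"Input validation error","error_type":"validation"}')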