Skip to content

Commit 2b95d2b

Browse files
authored
Fix error handling on HTTP 401 (#1904)
* Fix error handling on HTTP 401 * add test * fix test * fix regex for resolve endpoints
1 parent ab6bff1 commit 2b95d2b

File tree

2 files changed

+77
-18
lines changed

2 files changed

+77
-18
lines changed

src/huggingface_hub/utils/_errors.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
1+
import re
12
from typing import Optional
23

34
from requests import HTTPError, Response
45

5-
from ..constants import INFERENCE_ENDPOINTS_ENDPOINT
66
from ._fixes import JSONDecodeError
77

88

9+
# Pattern for Hub URLs that target a specific repository: either the repo API
# (`/api/{models,datasets,spaces}/<repo_id>...`) or the file-resolve route
# (`/<repo_id>/resolve/<revision>/...`), on the production or staging host.
# Listing endpoints (e.g. `/api/models`) and other hosts do not match.
REPO_API_REGEX = re.compile(
    r"""
        # staging or production endpoint
        # NOTE: dots are escaped so they only match a literal dot in the hostname
        ^https://(hub-ci\.)?huggingface\.co
        (
            # on /api/repo_type/repo_id
            /api/(models|datasets|spaces)/(.+)
            |
            # or /repo_id/resolve/revision/...
            /(.+)/resolve/(.+)
        )
    """,
    flags=re.VERBOSE,
)
23+
24+
925
class FileMetadataError(OSError):
1026
"""Error triggered when the metadata of a file on the Hub cannot be retrieved (missing ETag or commit_hash).
1127
@@ -285,25 +301,12 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None)
285301
)
286302
raise GatedRepoError(message, response) from e
287303

288-
elif (
289-
response.status_code == 401
290-
and response.request.url is not None
291-
and "/api/collections" in response.request.url
292-
):
293-
# Collection not found. We don't raise a custom error for this.
294-
# This prevent from raising a misleading `RepositoryNotFoundError` (see below).
295-
pass
296-
297-
elif (
304+
elif error_code == "RepoNotFound" or (
298305
response.status_code == 401
306+
and response.request is not None
299307
and response.request.url is not None
300-
and INFERENCE_ENDPOINTS_ENDPOINT in response.request.url
308+
and REPO_API_REGEX.search(response.request.url) is not None
301309
):
302-
# Not enough permission to list Inference Endpoints from this org. We don't raise a custom error for this.
303-
# This prevent from raising a misleading `RepositoryNotFoundError` (see below).
304-
pass
305-
306-
elif error_code == "RepoNotFound" or response.status_code == 401:
307310
# 401 is misleading as it is returned for:
308311
# - private and gated repos if user is not authenticated
309312
# - missing repos

tests/test_utils_errors.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import unittest
22

3+
import pytest
34
from requests.models import PreparedRequest, Response
45

56
from huggingface_hub.utils._errors import (
7+
REPO_API_REGEX,
68
BadRequestError,
79
EntryNotFoundError,
810
HfHubHTTPError,
@@ -23,17 +25,30 @@ def test_hf_raise_for_status_repo_not_found(self) -> None:
2325
self.assertEqual(context.exception.response.status_code, 404)
2426
self.assertIn("Request ID: 123", str(context.exception))
2527

26-
def test_hf_raise_for_status_repo_not_found_without_error_code(self) -> None:
28+
def test_hf_raise_for_status_401_repo_url(self) -> None:
    """A 401 response on a repo API URL must raise `RepositoryNotFoundError`."""
    request = PreparedRequest()
    request.url = "https://huggingface.co/api/models/username/reponame"

    response = Response()
    response.status_code = 401
    response.headers = {"X-Request-Id": 123}
    response.request = request

    with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as raised:
        hf_raise_for_status(response)

    # The original response and its request id must be carried by the error.
    error = raised.exception
    self.assertEqual(error.response.status_code, 401)
    self.assertIn("Request ID: 123", str(error))
3639

40+
def test_hf_raise_for_status_401_not_repo_url(self) -> None:
    """A 401 response on a non-repo URL must raise `HfHubHTTPError`."""
    request = PreparedRequest()
    request.url = "https://huggingface.co/api/collections"

    response = Response()
    response.status_code = 401
    response.headers = {"X-Request-Id": 123}
    response.request = request

    with self.assertRaises(HfHubHTTPError) as raised:
        hf_raise_for_status(response)

    # The original response and its request id must be carried by the error.
    error = raised.exception
    self.assertEqual(error.response.status_code, 401)
    self.assertIn("Request ID: 123", str(error))
51+
3752
def test_hf_raise_for_status_revision_not_found(self) -> None:
3853
response = Response()
3954
response.headers = {"X-Error-Code": "RevisionNotFound", "X-Request-Id": 123}
@@ -239,3 +254,44 @@ def test_hf_hub_http_error_init_with_error_message_duplicated_in_header_and_body
239254
"this is a message\n\nError message duplicated in headers and body.",
240255
)
241256
self.assertEqual(error.server_message, "Error message duplicated in headers and body.")
257+
258+
259+
@pytest.mark.parametrize(
    ("url", "should_match"),
    [
        # Listing endpoints are not tied to a single repo => no match
        ("https://huggingface.co/api/models", False),
        ("https://huggingface.co/api/datasets", False),
        ("https://huggingface.co/api/spaces", False),
        # Repo creation endpoint => no match
        ("https://huggingface.co/api/repos/create", False),
        # Collections => no match
        ("https://huggingface.co/api/collections", False),
        ("https://huggingface.co/api/collections/foo/bar", False),
        # Repo API endpoints => match
        ("https://huggingface.co/api/models/repo_id", True),
        ("https://huggingface.co/api/datasets/repo_id", True),
        ("https://huggingface.co/api/spaces/repo_id", True),
        ("https://huggingface.co/api/models/username/repo_name/refs/main", True),
        ("https://huggingface.co/api/datasets/username/repo_name/refs/main", True),
        ("https://huggingface.co/api/spaces/username/repo_name/refs/main", True),
        # Inference Endpoints API lives on another host => no match
        ("https://api.endpoints.huggingface.cloud/v2/endpoint/namespace", False),
        # Staging host => match
        ("https://hub-ci.huggingface.co/api/models/repo_id", True),
        ("https://hub-ci.huggingface.co/api/datasets/repo_id", True),
        ("https://hub-ci.huggingface.co/api/spaces/repo_id", True),
        # /resolve routes => match
        ("https://huggingface.co/gpt2/resolve/main/README.md", True),
        ("https://huggingface.co/datasets/google/fleurs/resolve/revision/README.md", True),
        # Regression cases
        ("https://huggingface.co/bert-base/resolve/main/pytorch_model.bin", True),
        ("https://hub-ci.huggingface.co/__DUMMY_USER__/repo-1470b5/resolve/main/file.txt", True),
    ],
)
def test_repo_api_regex(url: str, should_match: bool) -> None:
    """Check that `REPO_API_REGEX` matches `url` if and only if `should_match` is true."""
    matched = REPO_API_REGEX.match(url) is not None
    assert matched == should_match

0 commit comments

Comments
 (0)