Skip to content

Commit

Permalink
Make @retry_endpoint a default for all test (#1725)
Browse files Browse the repository at this point in the history
* Make @retry_endpoint a default for any test

* simplificaiton

* works
  • Loading branch information
Wauplin authored Oct 12, 2023
1 parent 5a4eab0 commit 51d9e94
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 157 deletions.
63 changes: 61 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import os
import shutil
import time
from functools import wraps
from pathlib import Path
from typing import Generator
from typing import Generator, List

import pytest
from _pytest.fixtures import SubRequest
from _pytest.python import Function as PytestFunction
from requests.exceptions import HTTPError

import huggingface_hub
from huggingface_hub import HfApi, HfFolder
from huggingface_hub.utils import SoftTemporaryDirectory
from huggingface_hub.utils import SoftTemporaryDirectory, logging
from huggingface_hub.utils._typing import CallableT

from .testing_constants import ENDPOINT_PRODUCTION, PRODUCTION_TOKEN
from .testing_utils import repo_name, set_write_permission_and_retry


logger = logging.get_logger(__name__)


@pytest.fixture
def fx_cache_dir(request: SubRequest) -> Generator[None, None, None]:
"""Add a `cache_dir` attribute pointing to a temporary directory in tests.
Expand Down Expand Up @@ -75,6 +83,57 @@ def disable_experimental_warnings(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(huggingface_hub.constants, "HF_HUB_DISABLE_EXPERIMENTAL_WARNING", True)


def retry_on_transient_error(fn: CallableT) -> CallableT:
"""
Retry test if failure because of unavailable service or race condition.
Tests are retried up to 3 times, waiting 5s between each try.
"""
NUMBER_OF_TRIES = 3
WAIT_TIME = 5

@wraps(fn)
def _inner(*args, **kwargs):
retry_count = 0
while True:
try:
return fn(*args, **kwargs)
except HTTPError as e:
if retry_count >= NUMBER_OF_TRIES:
raise
if e.response.status_code == 504:
logger.info(
f"Attempt {retry_count} failed with a 504 error. Retrying new execution in"
f" {WAIT_TIME} second(s)..."
)
else:
raise
except OSError:
if retry_count >= NUMBER_OF_TRIES:
raise
logger.info(
"Race condition met where we tried to `clone` before fully deleting a repository. Retrying new"
f" execution in {WAIT_TIME} second(s)..."
)
time.sleep(WAIT_TIME)
retry_count += 1

return _inner


def pytest_collection_modifyitems(items: List[PytestFunction]):
"""Alter all tests to retry on transient errors.
Note: equivalent to the previously used `@retry_endpoint` decorator, but tests do
not have to be decorated individually anymore.
"""
# called after collection is completed
# you can modify the ``items`` list
# see https://docs.pytest.org/en/7.3.x/how-to/writing_hook_functions.html
for item in items:
item.obj = retry_on_transient_error(item.obj)


@pytest.fixture
def fx_production_space(request: SubRequest) -> Generator[None, None, None]:
"""Add a `repo_id` attribute referencing a Space repo on the production Hub.
Expand Down
Loading

0 comments on commit 51d9e94

Please sign in to comment.