Skip to content

Commit

Permalink
Merge pull request #37 from deepghs/dev/retry
Browse files Browse the repository at this point in the history
dev(narugo): add retry session in entries
  • Loading branch information
narugo1992 authored Aug 9, 2024
2 parents 3459e6c + ced5b21 commit d63e354
Show file tree
Hide file tree
Showing 12 changed files with 257 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ hfutils.utils
download
number
path
session
tqdm_
walk

31 changes: 31 additions & 0 deletions docs/source/api_doc/utils/session.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
hfutils.utils.session
=================================

.. currentmodule:: hfutils.utils.session

.. automodule:: hfutils.utils.session



TimeoutHTTPAdapter
-----------------------------------------------------

.. autoclass:: TimeoutHTTPAdapter
:members: __init__, send



get_requests_session
-----------------------------------------------------

.. autofunction:: get_requests_session



get_random_ua
-----------------------------------------------------

.. autofunction:: get_random_ua



4 changes: 4 additions & 0 deletions hfutils/entry/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import Optional

import click
from huggingface_hub import configure_http_backend

from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException
from ..operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory
from ..operate.base import REPO_TYPES, RepoTypeTyping
from ..utils import get_requests_session


class NoRemotePathAssignedWithDownload(ClickErrorException):
Expand Down Expand Up @@ -84,6 +86,8 @@ def download(
:param tmpdir: Use custom temporary Directory.
:type tmpdir: str, optional
"""
configure_http_backend(get_requests_session)

if tmpdir:
os.environ['TMPDIR'] = tmpdir

Expand Down
6 changes: 5 additions & 1 deletion hfutils/entry/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import click
from hbutils.string import plural_word
from huggingface_hub import configure_http_backend

from .base import CONTEXT_SETTINGS
from ..cache import delete_detached_cache
from ..index import hf_tar_validate, tar_create_index
from ..operate import get_hf_fs, download_file_to_file, upload_directory_as_directory
from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client
from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter
from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter, \
get_requests_session


def _add_index_subcommand(cli: click.Group) -> click.Group:
Expand Down Expand Up @@ -86,6 +88,8 @@ def index(repo_id: str, idx_repo_id: Optional[str], repo_type: RepoTypeTyping, r
This function is typically invoked through the CLI interface, like:
$ python script.py index -r my_repo -x my_index_repo -t dataset -R main --min_upload_interval 120
"""
configure_http_backend(get_requests_session)

logger = logging.getLogger()
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import click
import tzlocal
from huggingface_hub import configure_http_backend
from huggingface_hub.hf_api import RepoFolder, RepoFile

from .base import CONTEXT_SETTINGS
from ..operate.base import REPO_TYPES, get_hf_client
from ..utils import get_requests_session

mimetypes.add_type('image/webp', '.webp')

Expand Down Expand Up @@ -124,6 +126,8 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool,
:param show_detailed: Flag to indicate whether to show detailed file information.
:type show_detailed: bool
"""
configure_http_backend(get_requests_session)

hf_client = get_hf_client()
items: List[ListItem] = []
for item in hf_client.list_repo_tree(
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/ls_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Optional

import click
from huggingface_hub import configure_http_backend
from huggingface_hub.utils import LocalTokenNotFoundError

from .base import CONTEXT_SETTINGS, ClickErrorException
from ..operate.base import REPO_TYPES, get_hf_client
from ..utils import get_requests_session


class NoLocalAuthentication(ClickErrorException):
Expand Down Expand Up @@ -46,6 +48,8 @@ def ls(author: Optional[str], repo_type: str, pattern: str):
:param pattern: Pattern of the repository names.
:type pattern: str
"""
configure_http_backend(get_requests_session)

hf_client = get_hf_client()
if not author:
try:
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Optional

import click
from huggingface_hub import configure_http_backend

from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException
from ..operate import upload_file_to_file, upload_directory_as_archive, upload_directory_as_directory
from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client
from ..utils import get_requests_session


class NoRemotePathAssignedWithUpload(ClickErrorException):
Expand Down Expand Up @@ -78,6 +80,8 @@ def upload(repo_id: str, repo_type: RepoTypeTyping,
:param public: Set public repository when created.
:type public: bool
"""
configure_http_backend(get_requests_session)

if not file_in_repo and not archive_in_repo and not dir_in_repo:
raise NoRemotePathAssignedWithUpload('No remote path in repository assigned.\n'
'One of the -f, -a, or -d option is required.')
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/whoami.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import click
from hbutils.string import plural_word
from huggingface_hub import configure_http_backend
from huggingface_hub.utils import LocalTokenNotFoundError

from .base import CONTEXT_SETTINGS
from ..operate.base import get_hf_client
from ..utils import get_requests_session


def _add_whoami_subcommand(cli: click.Group) -> click.Group:
Expand All @@ -28,6 +30,8 @@ def whoami():
This function retrieves the current user's identification from the Hugging Face Hub API and displays it.
"""
configure_http_backend(get_requests_session)

hf_client = get_hf_client()
try:
info = hf_client.whoami()
Expand Down
1 change: 1 addition & 0 deletions hfutils/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .logging import ColoredFormatter
from .number import number_to_tag
from .path import hf_normpath, hf_fs_path, parse_hf_fs_path, HfFileSystemPath
from .session import TimeoutHTTPAdapter, get_requests_session, get_random_ua
from .temp import TemporaryDirectory
from .tqdm_ import tqdm
from .walk import walk_files
115 changes: 115 additions & 0 deletions hfutils/utils/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
This module provides functionality for creating and managing HTTP sessions with customizable retry logic,
timeout settings, and user-agent rotation using random user-agent generation. It is designed to help with
robust web scraping and API consumption by handling common HTTP errors and timeouts gracefully.
Main Features:
- Automatic retries on specified HTTP response status codes.
- Configurable request timeout.
- Rotating user-agent for each session to mimic different browsers and operating systems.
- Optional SSL verification.
"""

from functools import lru_cache
from typing import Optional, Dict

import requests
from random_user_agent.params import SoftwareName, OperatingSystem
from random_user_agent.user_agent import UserAgent
from requests.adapters import HTTPAdapter, Retry

DEFAULT_TIMEOUT = 15 # seconds


class TimeoutHTTPAdapter(HTTPAdapter):
"""
A custom HTTPAdapter that enforces a default timeout on all requests.
:param args: Variable length argument list for HTTPAdapter.
:param kwargs: Arbitrary keyword arguments. 'timeout' can be specified to set a custom timeout.
"""

def __init__(self, *args, **kwargs):
self.timeout = DEFAULT_TIMEOUT
if "timeout" in kwargs:
self.timeout = kwargs["timeout"]
del kwargs["timeout"]
super().__init__(*args, **kwargs)

def send(self, request, **kwargs):
"""
Sends the Request object, applying the timeout setting.
:param request: The Request object to send.
:type request: requests.PreparedRequest
:param kwargs: Keyword arguments that may contain 'timeout'.
:return: The response to the request.
"""
timeout = kwargs.get("timeout")
if timeout is None:
kwargs["timeout"] = self.timeout
return super().send(request, **kwargs)


def get_requests_session(max_retries: int = 5, timeout: int = DEFAULT_TIMEOUT, verify: bool = True,
headers: Optional[Dict[str, str]] = None, session: Optional[requests.Session] = None) \
-> requests.Session:
"""
Creates a requests session with retry logic, timeout settings, and random user-agent headers.
:param max_retries: Maximum number of retries on failed requests.
:type max_retries: int
:param timeout: Request timeout in seconds.
:type timeout: int
:param verify: Whether to verify SSL certificates.
:type verify: bool
:param headers: Additional headers to include in the requests.
:type headers: Optional[Dict[str, str]]
:param session: An existing requests.Session instance to use.
:type session: Optional[requests.Session]
:return: A configured requests.Session object.
:rtype: requests.Session
"""
session = session or requests.session()
retries = Retry(
total=max_retries, backoff_factor=1,
status_forcelist=[408, 429, 500, 501, 502, 503, 504, 505, 506, 507, 509, 510, 511],
allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"],
)
adapter = TimeoutHTTPAdapter(max_retries=retries, timeout=timeout, pool_connections=32, pool_maxsize=32)
session.mount('http://', adapter)
session.mount('https://', adapter)
session.headers.update({
"User-Agent": get_random_ua(),
**dict(headers or {}),
})
if not verify:
session.verify = False

return session


@lru_cache()
def _ua_pool():
"""
Creates and caches a UserAgent rotator instance with a specified number of user agents.
:return: A UserAgent rotator instance.
:rtype: UserAgent
"""
software_names = [SoftwareName.CHROME.value, SoftwareName.FIREFOX.value, SoftwareName.EDGE.value]
operating_systems = [OperatingSystem.WINDOWS.value, OperatingSystem.MACOS.value]

user_agent_rotator = UserAgent(software_names=software_names, operating_systems=operating_systems, limit=1000)
return user_agent_rotator


def get_random_ua():
"""
Retrieves a random user agent string from the cached UserAgent rotator.
:return: A random user agent string.
:rtype: str
"""
return _ua_pool().get_random_user_agent()
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ click>=7
tzlocal
natsort
urlobject
fsspec>=2024
fsspec>=2024
random_user_agent
82 changes: 82 additions & 0 deletions test/utils/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from unittest.mock import patch, Mock

import pytest
import requests
from huggingface_hub import hf_hub_url
from requests.adapters import HTTPAdapter

from hfutils.utils.session import TimeoutHTTPAdapter, get_requests_session, get_random_ua


@pytest.fixture
def mock_requests_session():
with patch('requests.session') as mock_session:
yield mock_session.return_value


@pytest.fixture
def mock_ua_pool():
with patch('hfutils.utils.session._ua_pool') as mock_pool:
mock_pool.return_value.get_random_user_agent.return_value = 'MockUserAgent'
yield mock_pool


@pytest.fixture()
def example_url():
return hf_hub_url(
repo_id='deepghs/danbooru_newest',
repo_type='dataset',
filename='README.md'
)


@pytest.mark.unittest
class TestUtilsSession:
def test_timeout_http_adapter_init(self, ):
adapter = TimeoutHTTPAdapter()
assert adapter.timeout == 15

adapter = TimeoutHTTPAdapter(timeout=30)
assert adapter.timeout == 30

def test_timeout_http_adapter_send(self, ):
adapter = TimeoutHTTPAdapter(timeout=10)
mock_request = Mock()
mock_kwargs = {}

with patch.object(HTTPAdapter, 'send') as mock_send:
adapter.send(mock_request, **mock_kwargs)
mock_send.assert_called_once_with(mock_request, timeout=10)

mock_kwargs = {'timeout': 20}
with patch.object(HTTPAdapter, 'send') as mock_send:
adapter.send(mock_request, **mock_kwargs)
mock_send.assert_called_once_with(mock_request, timeout=20)

def test_get_requests_session(self, mock_ua_pool):
session = get_requests_session()
assert isinstance(session, requests.Session)
assert 'User-Agent' in session.headers
assert session.headers['User-Agent'] == 'MockUserAgent'

custom_headers = {'Custom-Header': 'Value'}
session = get_requests_session(headers=custom_headers)
assert 'Custom-Header' in session.headers
assert session.headers['Custom-Header'] == 'Value'

session = get_requests_session(verify=False)
assert session.verify is False

existing_session = requests.Session()
session = get_requests_session(session=existing_session)
assert session is existing_session

def test_get_requests_session_with_custom_params(self):
session = get_requests_session(max_retries=3, timeout=30)
assert isinstance(session, requests.Session)
# You might want to add more assertions here to check if the custom parameters are applied correctly

def test_get_random_ua(self, mock_ua_pool):
ua = get_random_ua()
assert ua == 'MockUserAgent'
mock_ua_pool.return_value.get_random_user_agent.assert_called_once()

0 comments on commit d63e354

Please sign in to comment.