50 changes: 48 additions & 2 deletions src/crawlee/_types.py
@@ -180,6 +180,17 @@ class AddRequestsKwargs(EnqueueLinksKwargs):
requests: Sequence[str | Request]
"""Requests to be added to the `RequestManager`."""

rq_id: str | None
"""ID of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be provided."""

rq_name: str | None
"""Name of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be provided.
"""

rq_alias: str | None
"""Alias of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be provided.
"""


class PushDataKwargs(TypedDict):
"""Keyword arguments for dataset's `push_data` method."""
@@ -261,10 +272,18 @@ def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
async def add_requests(
self,
requests: Sequence[str | Request],
rq_id: str | None = None,
rq_name: str | None = None,
rq_alias: str | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> None:
"""Track a call to the `add_requests` context helper."""
self.add_requests_calls.append(AddRequestsKwargs(requests=requests, **kwargs))
specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None)
if specified_params > 1:
raise ValueError('Only one of `rq_id`, `rq_name` or `rq_alias` can be provided.')
self.add_requests_calls.append(
AddRequestsKwargs(requests=requests, rq_id=rq_id, rq_name=rq_name, rq_alias=rq_alias, **kwargs)
)

async def push_data(
self,
@@ -311,12 +330,21 @@ class AddRequestsFunction(Protocol):
def __call__(
self,
requests: Sequence[str | Request],
rq_id: str | None = None,
rq_name: str | None = None,
rq_alias: str | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, None]:
"""Call dunder method.

Args:
requests: Requests to be added to the `RequestManager`.
rq_id: ID of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be
provided.
rq_name: Name of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias`
can be provided.
rq_alias: Alias of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias`
can be provided.
**kwargs: Additional keyword arguments.
"""

@@ -344,12 +372,21 @@ def __call__(
label: str | None = None,
user_data: dict[str, Any] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
rq_id: str | None = None,
rq_name: str | None = None,
rq_alias: str | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, None]: ...

@overload
def __call__(
self, *, requests: Sequence[str | Request] | None = None, **kwargs: Unpack[EnqueueLinksKwargs]
self,
*,
requests: Sequence[str | Request] | None = None,
rq_id: str | None = None,
rq_name: str | None = None,
rq_alias: str | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, None]: ...

def __call__(
@@ -360,6 +397,9 @@ def __call__(
user_data: dict[str, Any] | None = None,
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction] | None = None,
requests: Sequence[str | Request] | None = None,
rq_id: str | None = None,
rq_name: str | None = None,
rq_alias: str | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> Coroutine[None, None, None]:
"""Call enqueue links function.
@@ -377,6 +417,12 @@ def __call__(
- `'skip'` to exclude the request from being enqueued,
- `'unchanged'` to use the original request options without modification.
requests: Requests to be added to the `RequestManager`.
rq_id: ID of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias` can be
provided.
rq_name: Name of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias`
can be provided.
rq_alias: Alias of the `RequestQueue` to add the requests to. Only one of `rq_id`, `rq_name` or `rq_alias`
can be provided.
**kwargs: Additional keyword arguments.
"""

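Taken together, these signature changes let a request handler route requests into a queue other than the crawler's default one, selected by exactly one of `rq_id`, `rq_name`, or `rq_alias`. A minimal usage sketch under that assumption (the crawler class, URLs, and queue identifiers below are illustrative, not part of this diff):

import asyncio

from crawlee.crawlers import BeautifulSoupCrawler, BeautifulSoupCrawlingContext


async def main() -> None:
    crawler = BeautifulSoupCrawler()

    @crawler.router.default_handler
    async def handler(context: BeautifulSoupCrawlingContext) -> None:
        # Route extracted links to a separate queue identified by its alias.
        await context.enqueue_links(rq_alias='detail-pages')

        # Explicit requests can be routed the same way, here by queue name.
        await context.add_requests(['https://example.com/detail'], rq_name='detail-queue')

    await crawler.run(['https://example.com'])


if __name__ == '__main__':
    asyncio.run(main())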
30 changes: 28 additions & 2 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -944,6 +944,9 @@ async def enqueue_links(
transform_request_function: Callable[[RequestOptions], RequestOptions | RequestTransformAction]
| None = None,
requests: Sequence[str | Request] | None = None,
rq_id: str | None = None,
rq_name: str | None = None,
rq_alias: str | None = None,
**kwargs: Unpack[EnqueueLinksKwargs],
) -> None:
kwargs.setdefault('strategy', 'same-hostname')
@@ -955,7 +958,9 @@
'`transform_request_function` arguments when `requests` is provided.'
)
# Add directly passed requests.
await context.add_requests(requests or list[str | Request](), **kwargs)
await context.add_requests(
requests or list[str | Request](), rq_id=rq_id, rq_name=rq_name, rq_alias=rq_alias, **kwargs
)
else:
# Add requests from extracted links.
await context.add_requests(
@@ -965,6 +970,9 @@
user_data=user_data,
transform_request_function=transform_request_function,
),
rq_id=rq_id,
rq_name=rq_name,
rq_alias=rq_alias,
**kwargs,
)

@@ -1241,10 +1249,28 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) ->
"""Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
result = self._context_result_map[context]

request_manager = await self.get_request_manager()
base_request_manager = await self.get_request_manager()

origin = context.request.loaded_url or context.request.url

for add_requests_call in result.add_requests_calls:
rq_id = add_requests_call.get('rq_id')
rq_name = add_requests_call.get('rq_name')
rq_alias = add_requests_call.get('rq_alias')
specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None)
if specified_params > 1:
raise ValueError('You can only provide one of `rq_id`, `rq_name` or `rq_alias` arguments.')
if rq_id or rq_name or rq_alias:
request_manager: RequestManager | RequestQueue = await RequestQueue.open(
id=rq_id,
name=rq_name,
alias=rq_alias,
storage_client=self._service_locator.get_storage_client(),
configuration=self._service_locator.get_configuration(),
)
else:
request_manager = base_request_manager

requests = list[Request]()

base_url = url if (url := add_requests_call.get('base_url')) else origin
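Because the commit step opens the destination queue itself and only falls back to the crawler's `RequestManager` when none of the three parameters is set, requests routed this way can later be consumed independently of the crawl that produced them. A hedged sketch of one such consumer, assuming a named queue 'detail-queue' that an earlier run populated (the queue name and handler are illustrative):

from crawlee.crawlers import BasicCrawler, BasicCrawlingContext
from crawlee.storages import RequestQueue


async def consume_detail_queue() -> None:
    # Open the same named queue that the producing crawler routed requests into.
    detail_queue = await RequestQueue.open(name='detail-queue')

    # A second crawler reads from that queue instead of the default one.
    consumer = BasicCrawler(request_manager=detail_queue)

    @consumer.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url}')

    # No start requests are passed; the crawler drains the opened queue.
    await consumer.run()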
68 changes: 68 additions & 0 deletions tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -1549,3 +1549,71 @@ def listener(event_data: EventCrawlerStatusData) -> None:
event_manager.off(event=Event.CRAWLER_STATUS, listener=listener)

assert status_message_listener.called


@pytest.mark.parametrize(
('queue_name', 'queue_alias', 'by_id'),
[
pytest.param('named-queue', None, False, id='with rq_name'),
pytest.param(None, 'alias-queue', False, id='with rq_alias'),
pytest.param('id-queue', None, True, id='with rq_id'),
],
)
async def test_add_requests_with_rq_param(queue_name: str | None, queue_alias: str | None, *, by_id: bool) -> None:
crawler = BasicCrawler()
rq = await RequestQueue.open(name=queue_name, alias=queue_alias)
if by_id:
queue_id = rq.id
queue_name = None
else:
queue_id = None
visit_urls: set[str] = set()

check_requests = [
Request.from_url('https://a.placeholder.com'),
Request.from_url('https://b.placeholder.com'),
Request.from_url('https://c.placeholder.com'),
]

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
visit_urls.add(context.request.url)
await context.add_requests(check_requests, rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias)

await crawler.run(['https://start.placeholder.com'])

requests_from_queue = []
while request := await rq.fetch_next_request():
requests_from_queue.append(request)

assert requests_from_queue == check_requests
assert visit_urls == {'https://start.placeholder.com'}

await rq.drop()


@pytest.mark.parametrize(
('queue_name', 'queue_alias', 'queue_id'),
[
pytest.param('named-queue', 'alias-queue', None, id='rq_name and rq_alias'),
pytest.param('named-queue', None, 'id-queue', id='rq_name and rq_id'),
pytest.param(None, 'alias-queue', 'id-queue', id='rq_alias and rq_id'),
pytest.param('named-queue', 'alias-queue', 'id-queue', id='rq_name and rq_alias and rq_id'),
],
)
async def test_add_requests_error_with_multi_params(
queue_id: str | None, queue_name: str | None, queue_alias: str | None
) -> None:
crawler = BasicCrawler()

@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
with pytest.raises(ValueError, match='Only one of `rq_id`, `rq_name` or `rq_alias` can be provided'):
await context.add_requests(
[Request.from_url('https://a.placeholder.com')],
rq_id=queue_id,
rq_name=queue_name,
rq_alias=queue_alias,
)

await crawler.run(['https://start.placeholder.com'])
107 changes: 107 additions & 0 deletions tests/unit/crawlers/_beautifulsoup/test_beautifulsoup_crawler.py
@@ -3,8 +3,11 @@
from typing import TYPE_CHECKING
from unittest import mock

import pytest

from crawlee import ConcurrencySettings, Glob, HttpHeaders, RequestTransformAction, SkippedReason
from crawlee.crawlers import BeautifulSoupCrawler, BeautifulSoupCrawlingContext
from crawlee.storages import RequestQueue

if TYPE_CHECKING:
from yarl import URL
@@ -198,3 +201,107 @@ async def request_handler(context: BeautifulSoupCrawlingContext) -> None:

assert len(extracted_links) == 1
assert extracted_links[0] == str(server_url / 'page_1')


@pytest.mark.parametrize(
('queue_name', 'queue_alias', 'by_id'),
[
pytest.param('named-queue', None, False, id='with rq_name'),
pytest.param(None, 'alias-queue', False, id='with rq_alias'),
pytest.param('id-queue', None, True, id='with rq_id'),
],
)
async def test_enqueue_links_with_rq_param(
server_url: URL, http_client: HttpClient, queue_name: str | None, queue_alias: str | None, *, by_id: bool
) -> None:
crawler = BeautifulSoupCrawler(http_client=http_client)
rq = await RequestQueue.open(name=queue_name, alias=queue_alias)
if by_id:
queue_name = None
queue_id = rq.id
else:
queue_id = None
visit_urls: set[str] = set()

@crawler.router.default_handler
async def handler(context: BeautifulSoupCrawlingContext) -> None:
visit_urls.add(context.request.url)
await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias)

await crawler.run([str(server_url / 'start_enqueue')])

requests_from_queue: list[str] = []
while request := await rq.fetch_next_request():
requests_from_queue.append(request.url)

assert set(requests_from_queue) == {str(server_url / 'page_1'), str(server_url / 'sub_index')}
assert visit_urls == {str(server_url / 'start_enqueue')}

await rq.drop()


@pytest.mark.parametrize(
('queue_name', 'queue_alias', 'by_id'),
[
pytest.param('named-queue', None, False, id='with rq_name'),
pytest.param(None, 'alias-queue', False, id='with rq_alias'),
pytest.param('id-queue', None, True, id='with rq_id'),
],
)
async def test_enqueue_links_requests_with_rq_param(
server_url: URL, http_client: HttpClient, queue_name: str | None, queue_alias: str | None, *, by_id: bool
) -> None:
crawler = BeautifulSoupCrawler(http_client=http_client)
rq = await RequestQueue.open(name=queue_name, alias=queue_alias)
if by_id:
queue_name = None
queue_id = rq.id
else:
queue_id = None
visit_urls: set[str] = set()

check_requests: list[str] = [
'https://a.placeholder.com',
'https://b.placeholder.com',
'https://c.placeholder.com',
]

@crawler.router.default_handler
async def handler(context: BeautifulSoupCrawlingContext) -> None:
visit_urls.add(context.request.url)
await context.enqueue_links(
requests=check_requests, rq_name=queue_name, rq_alias=queue_alias, rq_id=queue_id, strategy='all'
)

await crawler.run([str(server_url / 'start_enqueue')])

requests_from_queue: list[str] = []
while request := await rq.fetch_next_request():
requests_from_queue.append(request.url)

assert set(requests_from_queue) == set(check_requests)
assert visit_urls == {str(server_url / 'start_enqueue')}

await rq.drop()


@pytest.mark.parametrize(
('queue_name', 'queue_alias', 'queue_id'),
[
pytest.param('named-queue', 'alias-queue', None, id='rq_name and rq_alias'),
pytest.param('named-queue', None, 'id-queue', id='rq_name and rq_id'),
pytest.param(None, 'alias-queue', 'id-queue', id='rq_alias and rq_id'),
pytest.param('named-queue', 'alias-queue', 'id-queue', id='rq_name and rq_alias and rq_id'),
],
)
async def test_enqueue_links_error_with_multi_params(
server_url: URL, http_client: HttpClient, queue_id: str | None, queue_name: str | None, queue_alias: str | None
) -> None:
crawler = BeautifulSoupCrawler(http_client=http_client)

@crawler.router.default_handler
async def handler(context: BeautifulSoupCrawlingContext) -> None:
with pytest.raises(ValueError, match='Only one of `rq_id`, `rq_name` or `rq_alias` can be provided'):
await context.enqueue_links(rq_id=queue_id, rq_name=queue_name, rq_alias=queue_alias)

await crawler.run([str(server_url / 'start_enqueue')])