feat: Add max_crawl_depth
Prathamesh010 committed Oct 31, 2024
1 parent 7d75289 commit e44360d
Showing 8 changed files with 99 additions and 7 deletions.
8 changes: 8 additions & 0 deletions src/crawlee/_request.py
@@ -61,6 +61,9 @@ class CrawleeRequestData(BaseModel):
    forefront: Annotated[bool, Field()] = False
    """Indicate whether the request should be enqueued at the front of the queue."""

    crawl_depth: Annotated[int, Field(alias='crawlDepth')] = 0
    """The depth of the request in the crawl tree."""


class UserData(BaseModel, MutableMapping[str, JsonSerializable]):
    """Represents the `user_data` part of a Request.
@@ -360,6 +363,11 @@ def crawlee_data(self) -> CrawleeRequestData:

        return user_data.crawlee_data

    @property
    def crawl_depth(self) -> int:
        """The depth of the request in the crawl tree."""
        return self.crawlee_data.crawl_depth

    @property
    def state(self) -> RequestState | None:
        """Crawlee-specific request handling state."""
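
The new field rides on the existing `CrawleeRequestData` model and is populated through its 'crawlDepth' alias; a minimal sketch of setting and reading it (the value is illustrative):

from crawlee._request import CrawleeRequestData

# 'crawlDepth' is the serialization alias of the new crawl_depth field; it defaults to 0 for start requests.
data = CrawleeRequestData(crawlDepth=1)
assert data.crawl_depth == 1
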
6 changes: 6 additions & 0 deletions src/crawlee/basic_crawler/_basic_crawler.py
@@ -120,6 +120,9 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]):
    configure_logging: NotRequired[bool]
    """If True, the crawler will set up logging infrastructure automatically."""

    max_crawl_depth: NotRequired[int | None]
    """Maximum crawl depth. If set, the crawler will stop crawling after reaching this depth."""

    _context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]]
    """Enables extending the request lifecycle and modifying the crawling context. Intended for use by
    subclasses rather than direct instantiation of `BasicCrawler`."""
@@ -174,6 +177,7 @@ def __init__(
        statistics: Statistics | None = None,
        event_manager: EventManager | None = None,
        configure_logging: bool = True,
        max_crawl_depth: int | None = None,
        _context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
        _additional_context_managers: Sequence[AsyncContextManager] | None = None,
        _logger: logging.Logger | None = None,
@@ -201,6 +205,7 @@ def __init__(
            statistics: A custom `Statistics` instance, allowing the use of non-default configuration.
            event_manager: A custom `EventManager` instance, allowing the use of non-default configuration.
            configure_logging: If True, the crawler will set up logging infrastructure automatically.
            max_crawl_depth: Maximum crawl depth. If set, the crawler will stop crawling after reaching this depth.
            _context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
                Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
            _additional_context_managers: Additional context managers used throughout the crawler lifecycle.
@@ -283,6 +288,7 @@ def __init__(

        self._running = False
        self._has_finished_before = False
        self._max_crawl_depth = max_crawl_depth

    @property
    def log(self) -> logging.Logger:
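
A minimal usage sketch of the new option through one of the concrete crawlers that forward their keyword arguments to `BasicCrawler`; the start URL is illustrative, and the handler simply logs the depth exposed by the new crawl_depth property:

import asyncio

from crawlee.beautifulsoup_crawler import BeautifulSoupCrawler, BeautifulSoupCrawlingContext


async def main() -> None:
    # The start request is depth 0; links found on depth-2 pages would become depth 3 and are skipped.
    crawler = BeautifulSoupCrawler(max_crawl_depth=2)

    @crawler.router.default_handler
    async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url} at depth {context.request.crawl_depth}')
        await context.enqueue_links()

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
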
11 changes: 10 additions & 1 deletion src/crawlee/beautifulsoup_crawler/_beautifulsoup_crawler.py
@@ -9,7 +9,7 @@
from typing_extensions import Unpack

from crawlee import EnqueueStrategy
-from crawlee._request import BaseRequestData
+from crawlee._request import BaseRequestData, CrawleeRequestData
from crawlee._utils.blocked import RETRY_CSS_SELECTORS
from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute
from crawlee.basic_crawler import BasicCrawler, BasicCrawlerOptions, ContextPipeline
@@ -181,6 +181,12 @@ async def enqueue_links(
        ) -> None:
            kwargs.setdefault('strategy', EnqueueStrategy.SAME_HOSTNAME)

            if self._max_crawl_depth is not None and context.request.crawl_depth + 1 > self._max_crawl_depth:
                context.log.info(
                    f'Skipping enqueue_links for URL "{context.request.url}" due to the maximum crawl depth limit.'
                )
                return

            requests = list[BaseRequestData]()
            user_data = user_data or {}

@@ -191,6 +197,9 @@
                if label is not None:
                    link_user_data.setdefault('label', label)

                data = {'crawlDepth': context.request.crawl_depth + 1}
                link_user_data.setdefault('__crawlee', CrawleeRequestData(**data))

                if (url := link.attrs.get('href')) is not None:
                    url = url.strip()

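The guard added above is plain arithmetic on the parent request's depth; a standalone sketch of the same decision (the helper name and the sample values are hypothetical, not part of this commit):

def should_enqueue(parent_depth: int, max_crawl_depth: int | None) -> bool:
    """Mirror of the guard in enqueue_links: a child request would get parent_depth + 1."""
    return max_crawl_depth is None or parent_depth + 1 <= max_crawl_depth


assert should_enqueue(parent_depth=0, max_crawl_depth=0) is False  # only the start page is processed
assert should_enqueue(parent_depth=0, max_crawl_depth=1) is True   # one level of links is followed
assert should_enqueue(parent_depth=1, max_crawl_depth=1) is False
assert should_enqueue(parent_depth=3, max_crawl_depth=None) is True  # None (the default) means no limit
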
11 changes: 10 additions & 1 deletion src/crawlee/parsel_crawler/_parsel_crawler.py
@@ -9,7 +9,7 @@
from typing_extensions import Unpack

from crawlee import EnqueueStrategy
-from crawlee._request import BaseRequestData
+from crawlee._request import BaseRequestData, CrawleeRequestData
from crawlee._utils.blocked import RETRY_CSS_SELECTORS
from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute
from crawlee.basic_crawler import BasicCrawler, BasicCrawlerOptions, ContextPipeline
@@ -180,6 +180,12 @@ async def enqueue_links(
        ) -> None:
            kwargs.setdefault('strategy', EnqueueStrategy.SAME_HOSTNAME)

            if self._max_crawl_depth is not None and context.request.crawl_depth + 1 > self._max_crawl_depth:
                context.log.info(
                    f'Skipping enqueue_links for URL "{context.request.url}" due to the maximum crawl depth limit.'
                )
                return

            requests = list[BaseRequestData]()
            user_data = user_data or {}

@@ -190,6 +196,9 @@
                if label is not None:
                    link_user_data.setdefault('label', label)

                data = {'crawlDepth': context.request.crawl_depth + 1}
                link_user_data.setdefault('__crawlee', CrawleeRequestData(**data))

                if (url := link.xpath('@href').get()) is not None:
                    url = url.strip()

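As in the hunk above, the child depth travels with each newly built request through the reserved '__crawlee' key of its user data; a simplified sketch of that hand-off (the parent depth and the label value are hypothetical):

from crawlee._request import BaseRequestData, CrawleeRequestData

parent_depth = 1  # would be context.request.crawl_depth on the page currently being processed

link_user_data: dict = {'label': 'DETAIL'}
# setdefault leaves a pre-existing '__crawlee' entry untouched if a caller already supplied one.
link_user_data.setdefault('__crawlee', CrawleeRequestData(crawlDepth=parent_depth + 1))

request = BaseRequestData.from_url('https://crawlee.dev/docs/examples', user_data=link_user_data)
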
11 changes: 10 additions & 1 deletion src/crawlee/playwright_crawler/_playwright_crawler.py
@@ -7,7 +7,7 @@
from typing_extensions import Unpack

from crawlee import EnqueueStrategy
-from crawlee._request import BaseRequestData
+from crawlee._request import BaseRequestData, CrawleeRequestData
from crawlee._utils.blocked import RETRY_CSS_SELECTORS
from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute
from crawlee.basic_crawler import BasicCrawler, BasicCrawlerOptions, ContextPipeline
@@ -168,6 +168,12 @@ async def enqueue_links(
            requests = list[BaseRequestData]()
            user_data = user_data or {}

            if self._max_crawl_depth is not None and context.request.crawl_depth + 1 > self._max_crawl_depth:
                context.log.info(
                    f'Skipping enqueue_links for URL "{context.request.url}" due to the maximum crawl depth limit.'
                )
                return

            elements = await context.page.query_selector_all(selector)

            for element in elements:
@@ -184,6 +190,9 @@
                    if label is not None:
                        link_user_data.setdefault('label', label)

                    data = {'crawlDepth': context.request.crawl_depth + 1}
                    link_user_data.setdefault('__crawlee', CrawleeRequestData(**data))

                    try:
                        request = BaseRequestData.from_url(url, user_data=link_user_data)
                    except ValidationError as exc:
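
The same keyword applies to the Playwright-based crawler; a sketch pairing it with the `Glob` include filter also used in the new tests below (the URLs are illustrative):

import asyncio

from crawlee import Glob
from crawlee.playwright_crawler import PlaywrightCrawler, PlaywrightCrawlingContext


async def main() -> None:
    crawler = PlaywrightCrawler(max_crawl_depth=1)

    @crawler.router.default_handler
    async def request_handler(context: PlaywrightCrawlingContext) -> None:
        context.log.info(f'Depth {context.request.crawl_depth}: {context.request.url}')
        # With max_crawl_depth=1, links are only enqueued from the depth-0 start page.
        await context.enqueue_links(include=[Glob('https://crawlee.dev/docs/**')])

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
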
18 changes: 17 additions & 1 deletion tests/unit/beautifulsoup_crawler/test_beautifulsoup_crawler.py
@@ -7,7 +7,7 @@
import respx
from httpx import Response

-from crawlee import ConcurrencySettings
+from crawlee import ConcurrencySettings, Glob
from crawlee.beautifulsoup_crawler import BeautifulSoupCrawler
from crawlee.storages import RequestList

@@ -165,3 +165,19 @@ async def test_handle_blocked_request(server: respx.MockRouter) -> None:
    stats = await crawler.run()
    assert server['incapsula_endpoint'].called
    assert stats.requests_failed == 1


async def test_enqueue_links_skips_when_crawl_depth_exceeded() -> None:
    crawler = BeautifulSoupCrawler(max_crawl_depth=0)
    visit = mock.Mock()

    @crawler.router.default_handler
    async def request_handler(context: BeautifulSoupCrawlingContext) -> None:
        visit(context.request.url)
        await context.enqueue_links(include=[Glob('https://crawlee.dev/docs/examples/**')])

    await crawler.run(['https://crawlee.dev/docs/examples'])

    visited = {call[0][0] for call in visit.call_args_list}

    assert len(visited) == 1
25 changes: 22 additions & 3 deletions tests/unit/parsel_crawler/test_parsel_crawler.py
@@ -8,8 +8,8 @@
import respx
from httpx import Response

-from crawlee import ConcurrencySettings
-from crawlee._request import BaseRequestData
+from crawlee import ConcurrencySettings, Glob
+from crawlee._request import BaseRequestData, CrawleeRequestData
from crawlee.parsel_crawler import ParselCrawler
from crawlee.storages import RequestList

@@ -171,7 +171,10 @@ async def request_handler(context: ParselCrawlingContext) -> None:
    }

    assert from_url.call_count == 1
-    assert from_url.call_args == (('https://test.io/asdf',), {'user_data': {'label': 'foo'}})
+    assert from_url.call_args == (
+        ('https://test.io/asdf',),
+        {'user_data': {'label': 'foo', '__crawlee': CrawleeRequestData(crawlDepth=1)}},
+    )


async def test_enqueue_links_with_max_crawl(server: respx.MockRouter) -> None:
@@ -281,3 +284,19 @@ async def request_handler(context: ParselCrawlingContext) -> None:
    assert handler.called

    assert handler.call_args[0][0] == ['<hello>world</hello>']


async def test_enqueue_links_skips_when_crawl_depth_exceeded() -> None:
    crawler = ParselCrawler(max_crawl_depth=0)
    visit = mock.Mock()

    @crawler.router.default_handler
    async def request_handler(context: ParselCrawlingContext) -> None:
        visit(context.request.url)
        await context.enqueue_links(include=[Glob('https://crawlee.dev/docs/examples/**')])

    await crawler.run(['https://crawlee.dev/docs/examples'])

    visited = {call[0][0] for call in visit.call_args_list}

    assert len(visited) == 1
16 changes: 16 additions & 0 deletions tests/unit/playwright_crawler/test_playwright_crawler.py
@@ -146,3 +146,19 @@ async def request_handler(_context: PlaywrightCrawlingContext) -> None:
    await crawler.run(['https://example.com', 'https://httpbin.org'])

    assert mock_hook.call_count == 2


async def test_enqueue_links_skips_when_crawl_depth_exceeded() -> None:
    crawler = PlaywrightCrawler(max_crawl_depth=0)
    visit = mock.Mock()

    @crawler.router.default_handler
    async def request_handler(context: PlaywrightCrawlingContext) -> None:
        visit(context.request.url)
        await context.enqueue_links(include=[Glob('https://crawlee.dev/docs/examples/**')])

    await crawler.run(['https://crawlee.dev/docs/examples'])

    visited = {call[0][0] for call in visit.call_args_list}

    assert len(visited) == 1
