From e5a4c6a570fc874cb5f9dc7f1ead58be5777ecba Mon Sep 17 00:00:00 2001
From: Prathamesh Gawas <prathameshgawas87@gmail.com>
Date: Wed, 30 Oct 2024 17:10:10 +0000
Subject: [PATCH] feat: Add max_crawl_depth

---
 src/crawlee/_request.py                      | 12 ++++++++
 src/crawlee/basic_crawler/_basic_crawler.py  | 30 ++++++++++++++-----
 .../unit/basic_crawler/test_basic_crawler.py | 29 ++++++++++++++++++
 3 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/src/crawlee/_request.py b/src/crawlee/_request.py
index cce5dd063f..8cbf6834d4 100644
--- a/src/crawlee/_request.py
+++ b/src/crawlee/_request.py
@@ -61,6 +61,9 @@ class CrawleeRequestData(BaseModel):
     forefront: Annotated[bool, Field()] = False
     """Indicate whether the request should be enqueued at the front of the queue."""
 
+    crawl_depth: Annotated[int, Field(alias='crawlDepth')] = 0
+    """The depth of the request in the crawl tree."""
+
 
 class UserData(BaseModel, MutableMapping[str, JsonSerializable]):
     """Represents the `user_data` part of a Request.
@@ -360,6 +363,15 @@ def crawlee_data(self) -> CrawleeRequestData:
 
         return user_data.crawlee_data
 
+    @property
+    def crawl_depth(self) -> int:
+        """The depth of the request in the crawl tree."""
+        return self.crawlee_data.crawl_depth
+
+    @crawl_depth.setter
+    def crawl_depth(self, new_value: int) -> None:
+        self.crawlee_data.crawl_depth = new_value
+
     @property
     def state(self) -> RequestState | None:
         """Crawlee-specific request handling state."""
diff --git a/src/crawlee/basic_crawler/_basic_crawler.py b/src/crawlee/basic_crawler/_basic_crawler.py
index 98af31d036..beca0da274 100644
--- a/src/crawlee/basic_crawler/_basic_crawler.py
+++ b/src/crawlee/basic_crawler/_basic_crawler.py
@@ -120,6 +120,10 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]):
     configure_logging: NotRequired[bool]
     """If True, the crawler will set up logging infrastructure automatically."""
 
+    max_crawl_depth: NotRequired[int | None]
+    """Limits crawl depth from 0 (initial requests) up to the specified `max_crawl_depth`.
+    Requests at the maximum depth are processed, but no further links are enqueued."""
+
     _context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]]
     """Enables extending the request lifecycle and modifying the crawling context. Intended for use by
     subclasses rather than direct instantiation of `BasicCrawler`."""
@@ -174,6 +178,7 @@ def __init__(
         statistics: Statistics | None = None,
         event_manager: EventManager | None = None,
         configure_logging: bool = True,
+        max_crawl_depth: int | None = None,
         _context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
         _additional_context_managers: Sequence[AsyncContextManager] | None = None,
         _logger: logging.Logger | None = None,
@@ -201,6 +206,7 @@ def __init__(
             statistics: A custom `Statistics` instance, allowing the use of non-default configuration.
             event_manager: A custom `EventManager` instance, allowing the use of non-default configuration.
             configure_logging: If True, the crawler will set up logging infrastructure automatically.
+            max_crawl_depth: Maximum crawl depth. If set, no new links are enqueued from requests at this depth.
             _context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
                 Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
             _additional_context_managers: Additional context managers used throughout the crawler lifecycle.
@@ -283,6 +289,7 @@ def __init__(
 
         self._running = False
         self._has_finished_before = False
+        self._max_crawl_depth = max_crawl_depth
 
     @property
     def log(self) -> logging.Logger:
@@ -787,14 +794,21 @@ async def _commit_request_handler_result(
                 else:
                     dst_request = Request.from_base_request_data(request)
 
-                if self._check_enqueue_strategy(
-                    add_requests_call.get('strategy', EnqueueStrategy.ALL),
-                    target_url=urlparse(dst_request.url),
-                    origin_url=urlparse(origin),
-                ) and self._check_url_patterns(
-                    dst_request.url,
-                    add_requests_call.get('include', None),
-                    add_requests_call.get('exclude', None),
+                # Update the crawl depth of the request.
+                dst_request.crawl_depth = context.request.crawl_depth + 1
+
+                if (
+                    (self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth)
+                    and self._check_enqueue_strategy(
+                        add_requests_call.get('strategy', EnqueueStrategy.ALL),
+                        target_url=urlparse(dst_request.url),
+                        origin_url=urlparse(origin),
+                    )
+                    and self._check_url_patterns(
+                        dst_request.url,
+                        add_requests_call.get('include', None),
+                        add_requests_call.get('exclude', None),
+                    )
                 ):
                     requests.append(dst_request)
 
diff --git a/tests/unit/basic_crawler/test_basic_crawler.py b/tests/unit/basic_crawler/test_basic_crawler.py
index d62ef3022f..a4e2afd7d2 100644
--- a/tests/unit/basic_crawler/test_basic_crawler.py
+++ b/tests/unit/basic_crawler/test_basic_crawler.py
@@ -654,6 +654,35 @@ async def handler(context: BasicCrawlingContext) -> None:
     assert stats.requests_finished == 3
 
 
+async def test_max_crawl_depth(httpbin: str) -> None:
+    processed_urls = []
+
+    start_request = Request.from_url('https://someplace.com/', label='start')
+    start_request.crawl_depth = 2
+
+    # Set max_concurrency to 1 to ensure deterministic testing of max_crawl_depth
+    crawler = BasicCrawler(
+        concurrency_settings=ConcurrencySettings(max_concurrency=1),
+        max_crawl_depth=2,
+        request_provider=RequestList([start_request]),
+    )
+
+    @crawler.router.handler('start')
+    async def start_handler(context: BasicCrawlingContext) -> None:
+        processed_urls.append(context.request.url)
+        await context.add_requests(['https://someplace.com/too-deep'])
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        processed_urls.append(context.request.url)
+
+    stats = await crawler.run()
+
+    assert len(processed_urls) == 1
+    assert stats.requests_total == 1
+    assert stats.requests_finished == 1
+
+
 def test_crawler_log() -> None:
     crawler = BasicCrawler()
     assert isinstance(crawler.log, logging.Logger)
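
A minimal usage sketch of the new option (not part of the patch), assuming the public API this change builds on: `BasicCrawler.run()` accepting start URLs, `router.default_handler`, `context.add_requests`, and the `crawlee.basic_crawler` import path as in this revision of the code base. The URLs and the growing-path scheme are purely illustrative.

import asyncio

from crawlee.basic_crawler import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    # Depth 0 is the start request; requests at depth 2 are still handled,
    # but the depth-3 links they add are never enqueued.
    crawler = BasicCrawler(max_crawl_depth=2)

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        print(f'{context.request.url} (depth {context.request.crawl_depth})')
        # Illustrative child URL, one level deeper on each call.
        await context.add_requests([f'{context.request.url}next/'])

    await crawler.run(['https://example.com/'])


if __name__ == '__main__':
    asyncio.run(main())

With `max_crawl_depth=2`, the handler runs for depths 0 through 2; the depth-3 request added from the last call is dropped by the new check in `_commit_request_handler_result`.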