From e5a4c6a570fc874cb5f9dc7f1ead58be5777ecba Mon Sep 17 00:00:00 2001
From: Prathamesh Gawas <prathameshgawas87@gmail.com>
Date: Wed, 30 Oct 2024 17:10:10 +0000
Subject: [PATCH] feat: Add max_crawl_depth option to BasicCrawler

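Add an optional `max_crawl_depth` argument to `BasicCrawler`. Each
`Request` now carries a `crawl_depth` (stored in its Crawlee-specific
data and exposed as a property): start requests are at depth 0 and every
request enqueued from a handler is one level deeper than its parent.
Requests at the maximum depth are still processed, but the links they
produce are filtered out in `_commit_request_handler_result`.

A minimal usage sketch (import paths and the target URL are assumptions,
not part of this change):

    import asyncio

    from crawlee.basic_crawler import BasicCrawler, BasicCrawlingContext


    async def main() -> None:
        # Pages at depth 2 are still handled, but the links they add are dropped.
        crawler = BasicCrawler(max_crawl_depth=2)

        @crawler.router.default_handler
        async def handler(context: BasicCrawlingContext) -> None:
            crawler.log.info(f'{context.request.url} (depth {context.request.crawl_depth})')
            await context.add_requests(['https://crawlee.dev/'])

        await crawler.run(['https://crawlee.dev/'])


    asyncio.run(main())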
---
 src/crawlee/_request.py                       | 12 ++++++++
 src/crawlee/basic_crawler/_basic_crawler.py   | 30 ++++++++++++++-----
 .../unit/basic_crawler/test_basic_crawler.py  | 29 ++++++++++++++++++
 3 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/src/crawlee/_request.py b/src/crawlee/_request.py
index cce5dd063f..8cbf6834d4 100644
--- a/src/crawlee/_request.py
+++ b/src/crawlee/_request.py
@@ -61,6 +61,9 @@ class CrawleeRequestData(BaseModel):
     forefront: Annotated[bool, Field()] = False
     """Indicate whether the request should be enqueued at the front of the queue."""
 
+    crawl_depth: Annotated[int, Field(alias='crawlDepth')] = 0
+    """The depth of the request in the crawl tree."""
+
 
 class UserData(BaseModel, MutableMapping[str, JsonSerializable]):
     """Represents the `user_data` part of a Request.
@@ -360,6 +363,15 @@ def crawlee_data(self) -> CrawleeRequestData:
 
         return user_data.crawlee_data
 
+    @property
+    def crawl_depth(self) -> int:
+        """The depth of the request in the crawl tree."""
+        return self.crawlee_data.crawl_depth
+
+    @crawl_depth.setter
+    def crawl_depth(self, new_value: int) -> None:
+        self.crawlee_data.crawl_depth = new_value
+
     @property
     def state(self) -> RequestState | None:
         """Crawlee-specific request handling state."""
diff --git a/src/crawlee/basic_crawler/_basic_crawler.py b/src/crawlee/basic_crawler/_basic_crawler.py
index 98af31d036..beca0da274 100644
--- a/src/crawlee/basic_crawler/_basic_crawler.py
+++ b/src/crawlee/basic_crawler/_basic_crawler.py
@@ -120,6 +120,10 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]):
     configure_logging: NotRequired[bool]
     """If True, the crawler will set up logging infrastructure automatically."""
 
+    max_crawl_depth: NotRequired[int | None]
+    """Limits crawl depth from 0 (initial requests) up to the specified `max_crawl_depth`.
+    Requests at the maximum depth are processed, but no further links are enqueued."""
+
     _context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]]
     """Enables extending the request lifecycle and modifying the crawling context. Intended for use by
     subclasses rather than direct instantiation of `BasicCrawler`."""
@@ -174,6 +178,7 @@ def __init__(
         statistics: Statistics | None = None,
         event_manager: EventManager | None = None,
         configure_logging: bool = True,
+        max_crawl_depth: int | None = None,
         _context_pipeline: ContextPipeline[TCrawlingContext] | None = None,
         _additional_context_managers: Sequence[AsyncContextManager] | None = None,
         _logger: logging.Logger | None = None,
@@ -201,6 +206,7 @@ def __init__(
             statistics: A custom `Statistics` instance, allowing the use of non-default configuration.
             event_manager: A custom `EventManager` instance, allowing the use of non-default configuration.
             configure_logging: If True, the crawler will set up logging infrastructure automatically.
+            max_crawl_depth: Maximum crawl depth. Requests at this depth are handled, but no further links are enqueued.
             _context_pipeline: Enables extending the request lifecycle and modifying the crawling context.
                 Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.
             _additional_context_managers: Additional context managers used throughout the crawler lifecycle.
@@ -283,6 +289,7 @@ def __init__(
 
         self._running = False
         self._has_finished_before = False
+        self._max_crawl_depth = max_crawl_depth
 
     @property
     def log(self) -> logging.Logger:
@@ -787,14 +794,21 @@ async def _commit_request_handler_result(
                 else:
                     dst_request = Request.from_base_request_data(request)
 
-                if self._check_enqueue_strategy(
-                    add_requests_call.get('strategy', EnqueueStrategy.ALL),
-                    target_url=urlparse(dst_request.url),
-                    origin_url=urlparse(origin),
-                ) and self._check_url_patterns(
-                    dst_request.url,
-                    add_requests_call.get('include', None),
-                    add_requests_call.get('exclude', None),
+                # The new request is one level deeper in the crawl tree than the request that produced it.
+                dst_request.crawl_depth = context.request.crawl_depth + 1
+
+                if (
+                    (self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth)
+                    and self._check_enqueue_strategy(
+                        add_requests_call.get('strategy', EnqueueStrategy.ALL),
+                        target_url=urlparse(dst_request.url),
+                        origin_url=urlparse(origin),
+                    )
+                    and self._check_url_patterns(
+                        dst_request.url,
+                        add_requests_call.get('include', None),
+                        add_requests_call.get('exclude', None),
+                    )
                 ):
                     requests.append(dst_request)
 
diff --git a/tests/unit/basic_crawler/test_basic_crawler.py b/tests/unit/basic_crawler/test_basic_crawler.py
index d62ef3022f..a4e2afd7d2 100644
--- a/tests/unit/basic_crawler/test_basic_crawler.py
+++ b/tests/unit/basic_crawler/test_basic_crawler.py
@@ -654,6 +654,35 @@ async def handler(context: BasicCrawlingContext) -> None:
     assert stats.requests_finished == 3
 
 
+async def test_max_crawl_depth() -> None:
+    processed_urls = []
+
+    start_request = Request.from_url('https://someplace.com/', label='start')
+    start_request.crawl_depth = 2  # Already at the maximum depth, so its links must not be enqueued
+
+    # Set max_concurrency to 1 so requests are processed one at a time when testing max_crawl_depth
+    crawler = BasicCrawler(
+        concurrency_settings=ConcurrencySettings(max_concurrency=1),
+        max_crawl_depth=2,
+        request_provider=RequestList([start_request]),
+    )
+
+    @crawler.router.handler('start')
+    async def start_handler(context: BasicCrawlingContext) -> None:
+        processed_urls.append(context.request.url)
+        await context.add_requests(['https://someplace.com/too-deep'])
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        processed_urls.append(context.request.url)
+
+    stats = await crawler.run()
+
+    assert len(processed_urls) == 1
+    assert stats.requests_total == 1
+    assert stats.requests_finished == 1
+
+
 def test_crawler_log() -> None:
     crawler = BasicCrawler()
     assert isinstance(crawler.log, logging.Logger)