From cfe905b5529b0d69a57389d625c906ce6fbe0ff2 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Tue, 13 Aug 2024 19:25:28 +0200 Subject: [PATCH 01/10] Carry over RequestQueueV2 logic from JS --- src/crawlee/_utils/lru_cache.py | 2 +- src/crawlee/storages/request_queue.py | 394 ++++++++++++++------------ 2 files changed, 218 insertions(+), 178 deletions(-) diff --git a/src/crawlee/_utils/lru_cache.py b/src/crawlee/_utils/lru_cache.py index 057ed34b6..1d0d848ee 100644 --- a/src/crawlee/_utils/lru_cache.py +++ b/src/crawlee/_utils/lru_cache.py @@ -8,7 +8,7 @@ T = TypeVar('T') -class LRUCache(MutableMapping, Generic[T]): +class LRUCache(MutableMapping[str, T], Generic[T]): """Attempt to reimplement LRUCache from `@apify/datastructures` using `OrderedDict`.""" def __init__(self, max_length: int) -> None: diff --git a/src/crawlee/storages/request_queue.py b/src/crawlee/storages/request_queue.py index 51ad0d204..99cb8c112 100644 --- a/src/crawlee/storages/request_queue.py +++ b/src/crawlee/storages/request_queue.py @@ -2,9 +2,10 @@ import asyncio from collections import OrderedDict +from contextlib import suppress from datetime import datetime, timedelta, timezone from logging import getLogger -from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar from typing_extensions import override @@ -12,11 +13,12 @@ from crawlee._utils.lru_cache import LRUCache from crawlee._utils.requests import unique_key_to_request_id from crawlee._utils.wait import wait_for_all_tasks_for_finish +from crawlee.events.event_manager import EventManager +from crawlee.events.types import Event from crawlee.models import ( BaseRequestData, ProcessedRequest, Request, - RequestQueueHeadState, RequestQueueMetadata, ) from crawlee.storages.base_storage import BaseStorage @@ -28,10 +30,9 @@ from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration -logger = getLogger(__name__) - __all__ = ['RequestQueue'] +logger = getLogger(__name__) T = TypeVar('T') @@ -63,6 +64,8 @@ def clear(self) -> None: class CachedRequest(TypedDict): id: str was_already_handled: bool + hydrated: Request | None + lock_expires_at: datetime | None class RequestQueue(BaseStorage, RequestProvider): @@ -85,24 +88,9 @@ class RequestQueue(BaseStorage, RequestProvider): rq = await RequestQueue.open(id='my_rq_id') """ - _API_PROCESSED_REQUESTS_DELAY = timedelta(seconds=10) - """Delay threshold to assume consistency of queue head operations after queue modifications.""" - _MAX_CACHED_REQUESTS = 1_000_000 """Maximum number of requests that can be cached.""" - _MAX_HEAD_LIMIT = 1000 - """Cap on requests in progress when querying queue head.""" - - _MAX_QUERIES_FOR_CONSISTENCY = 6 - """Maximum attempts to fetch a consistent queue head.""" - - _QUERY_HEAD_BUFFER = 3 - """Multiplier for determining the number of requests to fetch based on in-progress requests.""" - - _QUERY_HEAD_MIN_LENGTH = 100 - """The minimum number of requests fetched when querying the queue head.""" - _RECENTLY_HANDLED_CACHE_SIZE = 1000 """Cache size for recently handled requests.""" @@ -115,6 +103,7 @@ def __init__( name: str | None, configuration: Configuration, client: BaseStorageClient, + event_manager: EventManager, ) -> None: self._id = id self._name = name @@ -124,14 +113,21 @@ def __init__( self._resource_client = client.request_queue(self._id) self._resource_collection_client = client.request_queues() + self._request_lock_time = timedelta(minutes=3) + 
self._queue_paused_for_migration = False + + event_manager.on(event=Event.MIGRATING, listener=lambda _: setattr(self, '_queue_paused_for_migration', True)) + event_manager.on(event=Event.MIGRATING, listener=lambda _: self._clear_possible_locks()) + event_manager.on(event=Event.ABORTING, listener=lambda _: self._clear_possible_locks()) + # Other internal attributes self._tasks = list[asyncio.Task]() self._client_key = crypto_random_object_id() - self._internal_timeout_seconds = 5 * 60 + self._internal_timeout = configuration.internal_timeout or timedelta(minutes=5) self._assumed_total_count = 0 self._assumed_handled_count = 0 self._queue_head_dict: OrderedDict[str, str] = OrderedDict() - self._query_queue_head_task: asyncio.Task | None = None + self._list_head_and_lock_task: asyncio.Task | None = None self._in_progress: set[str] = set() self._last_activity = datetime.now(timezone.utc) self._recently_handled: BoundedSet[str] = BoundedSet(max_length=self._RECENTLY_HANDLED_CACHE_SIZE) @@ -210,11 +206,7 @@ async def add_request( use_extended_unique_key: Determines whether to use an extended unique key, incorporating the request's method and payload into the unique key computation. - Returns: A dictionary containing information about the operation, including: - - `requestId` The ID of the request. - - `uniqueKey` The unique key associated with the request. - - `wasAlreadyPresent` (bool): Indicates whether the request was already in the queue. - - `wasAlreadyHandled` (bool): Indicates whether the request was already processed. + Returns: Information about the processed request """ request = self._transform_request(request) self._last_activity = datetime.now(timezone.utc) @@ -248,7 +240,6 @@ async def add_request( and request_id not in self._recently_handled ): self._assumed_total_count += 1 - self._maybe_add_request_to_queue_head(request_id, forefront=forefront) return processed_request @@ -290,7 +281,7 @@ async def _process_remaining_batches() -> None: # Wait for all tasks to finish if requested if wait_for_all_requests_to_be_added: await wait_for_all_tasks_for_finish( - self._tasks, + (remaining_batches_task,), logger=logger, timeout=wait_for_all_requests_to_be_added_timeout, ) @@ -321,6 +312,8 @@ async def fetch_next_request(self) -> Request | None: Returns: The request or `None` if there are no more pending requests. """ + self._last_activity = datetime.now(timezone.utc) + await self._ensure_head_is_non_empty() # We are likely done at this point. @@ -340,11 +333,11 @@ async def fetch_next_request(self) -> Request | None: }, ) return None + self._in_progress.add(next_request_id) - self._last_activity = datetime.now(timezone.utc) try: - request = await self.get_request(next_request_id) + request = await self._get_or_hydrate_request(next_request_id) except Exception: # On error, remove the request from in progress, otherwise it would be there forever self._in_progress.remove(next_request_id) @@ -447,19 +440,18 @@ async def reclaim_request( processed_request.unique_key = request.unique_key self._cache_request(unique_key_to_request_id(request.unique_key), processed_request) - # Wait a little to increase a chance that the next call to fetchNextRequest() will return the request with - # updated data. This is to compensate for the limitation of DynamoDB, where writes might not be immediately - # visible to subsequent reads. 
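
The callback-based block removed in the surrounding lines re-queued a reclaimed request only after a fixed consistency delay, to paper over eventually consistent storage reads; the lock-based code that replaces it simply releases the per-request lock so the next `list_and_lock_head()` call can pick the request up again. A minimal sketch contrasting the two strategies, assuming an eventually consistent backend; the `client` parameter is a stand-in for the resource client, not the real API:

    import asyncio
    from datetime import timedelta

    STORAGE_CONSISTENCY_DELAY = timedelta(seconds=3)

    async def reclaim_with_delay(queue_head: dict[str, str], request_id: str) -> None:
        # Old approach: put the request back at the queue head only after a
        # fixed delay, hoping the backend is consistent by then.
        loop = asyncio.get_running_loop()
        loop.call_later(
            STORAGE_CONSISTENCY_DELAY.total_seconds(),
            lambda: queue_head.setdefault(request_id, request_id),
        )

    async def reclaim_with_lock(client, request_id: str) -> None:
        # New approach: drop our lock and let any client's next
        # list_and_lock_head() call lock the request again immediately.
        try:
            await client.delete_request_lock(request_id)
        except Exception:
            pass  # losing this race is fine; another client holds the lock

Deleting the lock is idempotent from the queue's perspective, which is why the failure path in the new code only logs at debug level.
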
- def callback() -> None: - if request.id not in self._in_progress: - logger.debug(f'The request (ID: {request.id}) is no longer marked as in progress in the queue?!') - return + if processed_request: + # Mark the request as no longer in progress, + # as the moment we delete the lock, we could end up also re-fetching the request in a subsequent + # _ensure_head_is_non_empty() which could potentially lock the request again + self._in_progress.discard(request.id) - self._in_progress.remove(request.id) - # Performance optimization: add request straight to head if possible - self._maybe_add_request_to_queue_head(request.id, forefront=forefront) + # Try to delete the request lock if possible + try: + await self._resource_client.delete_request_lock(request.id, forefront=forefront) + except Exception as err: + logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err) - asyncio.get_running_loop().call_later(self._STORAGE_CONSISTENCY_DELAY.total_seconds(), callback) return processed_request async def is_empty(self) -> bool: @@ -481,21 +473,49 @@ async def is_finished(self) -> bool: Returns: bool: `True` if all requests were already handled and there are no more left. `False` otherwise. """ - seconds_since_last_activity = (datetime.now(timezone.utc) - self._last_activity).total_seconds() - if self._in_progress_count() > 0 and seconds_since_last_activity > self._internal_timeout_seconds: - message = ( - f'The request queue seems to be stuck for {self._internal_timeout_seconds}s, resetting internal state.' + seconds_since_last_activity = datetime.now(timezone.utc) - self._last_activity + if self._in_progress_count() > 0 and seconds_since_last_activity > self._internal_timeout: + logger.warning( + f'The request queue seems to be stuck for {self._internal_timeout.total_seconds()}s, ' + 'resetting internal state.', + extra={ + 'queue_head_ids_pending': len(self._queue_head_dict), + 'in_progress': list(self._in_progress), + }, + ) + + # We only need to reset these two variables, no need to reset all the other stats + self._queue_head_dict.clear() + self._in_progress.clear() + + if self._queue_head_dict: + logger.debug( + 'There are still ids in the queue head that are pending processing', + extra={ + 'queue_head_ids_pending': len(self._queue_head_dict), + }, + ) + + return False + + if self._in_progress: + logger.debug( + 'There are still requests in progress (or zombie)', + extra={ + 'in_progress': list(self._in_progress), + }, ) - logger.warning(message) - self._reset() - if len(self._queue_head_dict) > 0 or self._in_progress_count() > 0: return False - # TODO: set ensure_consistency to True once the following issue is resolved: - # https://github.com/apify/crawlee-python/issues/203 - is_head_consistent = await self._ensure_head_is_non_empty(ensure_consistency=False) - return is_head_consistent and len(self._queue_head_dict) == 0 and self._in_progress_count() == 0 + current_head = await self._resource_client.list_head(limit=2) + + if current_head.items: + logger.debug( + 'Queue head still returned requests that need to be processed (or that are locked by other clients)', + ) + + return not current_head.items and not self._in_progress async def get_info(self) -> RequestQueueMetadata | None: """Get an object containing general information about the request queue.""" @@ -509,102 +529,73 @@ async def get_handled_count(self) -> int: async def get_total_count(self) -> int: return self._assumed_total_count - async def _ensure_head_is_non_empty( - self, - *, - ensure_consistency: 
bool = False, - limit: int | None = None, - iteration: int = 0, - ) -> bool: - """Ensure that the queue head is non-empty. + async def _ensure_head_is_non_empty(self) -> None: + # Stop fetching if we are paused for migration + if self._queue_paused_for_migration: + return - The method ensures that the queue head contains items. It may request more items than are currently - in progress to guarantee that at least one item is present in the head of the queue. - - Args: - ensure_consistency: If True, the query for the queue head is retried until the queue_modified_at is older - than query_started_at by at least API_PROCESSED_REQUESTS_DELAY to ensure that the queue head is - consistent. - limit: The maximum number of items to fetch from the queue. - iteration: To manage the recursion depth. - - Returns: - True if the queue head is non-empty and consistent, False otherwise. - """ - # If queue head is non-empty, returns True immediately - if len(self._queue_head_dict) > 0: - return True + # We want to fetch ahead of time to minimize dead time + if len(self._queue_head_dict) > 1: + return - if limit is None: - limit = max(self._in_progress_count() * self._QUERY_HEAD_BUFFER, self._QUERY_HEAD_MIN_LENGTH) + if self._list_head_and_lock_task is None: + task = asyncio.create_task(self._list_head_and_lock()) - if self._query_queue_head_task is None: - self._query_queue_head_task = asyncio.Task(self._queue_query_head(limit)) + def callback(_: Any) -> None: + self._list_head_and_lock_task = None - queue_head: RequestQueueHeadState = await self._query_queue_head_task + task.add_done_callback(callback) + self._list_head_and_lock_task = task - # TODO: I feel this code below can be greatly simplified... (comes from TS implementation *wink*) - # https://github.com/apify/apify-sdk-python/issues/142 + await self._list_head_and_lock_task - # If queue is still empty then one of the following holds: - # - the other calls waiting for this task already consumed all the returned requests - # - the limit was too low and contained only requests in progress - # - the writes from other clients were not propagated yet - # - the whole queue was processed and we are done - - # If limit was not reached in the call then there are no more requests to be returned. - if queue_head.prev_limit >= self._MAX_HEAD_LIMIT: - logger.warning(f'Reached the maximum number of requests in progress (limit: {self._MAX_HEAD_LIMIT})') - - should_repeat_with_higher_limit = ( - len(self._queue_head_dict) == 0 - and queue_head.was_limit_reached - and queue_head.prev_limit < self._MAX_HEAD_LIMIT + async def _list_head_and_lock(self) -> None: + response = await self._resource_client.list_and_lock_head( + limit=25, lock_secs=int(self._request_lock_time.total_seconds()) ) - # If ensure_consistency is True, we must ensure the database is consistent. It can be ensured if either: - # - queue_modified_at is older than query_started_at by at least _API_PROCESSED_REQUESTS_DELAY - # - had_multiple_clients is False and _assumed_total_count is less than _assumed_handled_count - queue_latency = queue_head.query_started_at - queue_head.queue_modified_at.replace(tzinfo=timezone.utc) - is_database_consistent = queue_latency.total_seconds() >= self._API_PROCESSED_REQUESTS_DELAY.total_seconds() - - is_locally_consistent = ( - not queue_head.had_multiple_clients and self._assumed_total_count <= self._assumed_handled_count - ) - - # Consistent information from one source is enough to consider request queue finished. 
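
The consistency-retry machinery being removed here is superseded by the much simpler `_list_head_and_lock()` shown above, whose callers are funneled through a single shared task so that concurrent `_ensure_head_is_non_empty()` calls trigger only one head query. A self-contained sketch of that single-flight pattern, with a short sleep standing in for the storage client call:

    from __future__ import annotations

    import asyncio
    from typing import Any

    class HeadFetcher:
        """Toy model of the shared list-head task, not the real RequestQueue."""

        def __init__(self) -> None:
            self._task: asyncio.Task | None = None

        async def _list_head_and_lock(self) -> None:
            await asyncio.sleep(0.1)  # stands in for list_and_lock_head()

        async def ensure_head_is_non_empty(self) -> None:
            if self._task is None:
                task = asyncio.create_task(self._list_head_and_lock())

                def _clear(_: Any) -> None:
                    self._task = None  # let the next caller start a fresh fetch

                task.add_done_callback(_clear)
                self._task = task

            # Every concurrent caller awaits the same in-flight fetch.
            await self._task

    async def main() -> None:
        fetcher = HeadFetcher()
        await asyncio.gather(*(fetcher.ensure_head_is_non_empty() for _ in range(5)))

    asyncio.run(main())

One ordering subtlety the real code shares with this sketch: the done-callback clears the task reference, so a caller that arrives after completion starts a new fetch instead of awaiting a stale one.
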
- should_repeat_for_consistency = ensure_consistency and not is_database_consistent and not is_locally_consistent - - # If both are false then head is consistent and we may exit. - if not should_repeat_with_higher_limit and not should_repeat_for_consistency: - return True - - # If we are querying for consistency then we limit the number of queries to MAX_QUERIES_FOR_CONSISTENCY. - # If this is reached then we return false so that empty() and finished() returns possibly false negative. - if not should_repeat_with_higher_limit and iteration > self._MAX_QUERIES_FOR_CONSISTENCY: - return False - - next_limit = round(queue_head.prev_limit * 1.5) if should_repeat_with_higher_limit else queue_head.prev_limit + for request in response.items: + # Queue head index might be behind the main table, so ensure we don't recycle requests + if ( + not request.id + or not request.unique_key + or request.id in self._in_progress + or request.id in self._recently_handled + ): + logger.debug( + 'Skipping request from queue head, already in progress or recently handled', + extra={ + 'id': request.id, + 'unique_key': request.unique_key, + 'in_progress': request.id in self._in_progress, + 'recently_handled': request.id in self._recently_handled, + }, + ) + + # Remove the lock from the request for now, so that it can be picked up later + # This may/may not succeed, but that's fine + with suppress(Exception): + await self._resource_client.delete_request_lock(request.id) - # If we are repeating for consistency then wait required time. - if should_repeat_for_consistency: - elapsed_time = (datetime.now(timezone.utc) - queue_head.queue_modified_at).total_seconds() - delay_seconds = self._API_PROCESSED_REQUESTS_DELAY.total_seconds() - elapsed_time - logger.info(f'Waiting for {delay_seconds} for queue finalization, to ensure data consistency.') - await asyncio.sleep(delay_seconds) + continue - return await self._ensure_head_is_non_empty( - ensure_consistency=ensure_consistency, - limit=next_limit, - iteration=iteration + 1, - ) + self._queue_head_dict[request.id] = request.id + self._cache_request( + unique_key_to_request_id(request.unique_key), + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ), + ) def _in_progress_count(self) -> int: return len(self._in_progress) def _reset(self) -> None: self._queue_head_dict.clear() - self._query_queue_head_task = None + self._list_head_and_lock_task = None self._in_progress.clear() self._recently_handled.clear() self._assumed_total_count = 0 @@ -616,56 +607,105 @@ def _cache_request(self, cache_key: str, processed_request: ProcessedRequest) -> self._requests_cache[cache_key] = { 'id': processed_request.id, 'was_already_handled': processed_request.was_already_handled, + 'hydrated': None, + 'lock_expires_at': None, } - async def _queue_query_head(self, limit: int) -> RequestQueueHeadState: - query_started_at = datetime.now(timezone.utc) + async def _get_or_hydrate_request(self, request_id: str) -> Request | None: + cached_entry = self._requests_cache.get(request_id) - list_head = await self._resource_client.list_head(limit=limit) - list_head_items: list[Request] = list_head.items + if not cached_entry: + # 2.1. 
Attempt to prolong the request lock to see if we still own the request + prolong_result = await self._prolong_request_lock(request_id) - for request in list_head_items: - # Queue head index might be behind the main table, so ensure we don't recycle requests - if ( - not request.id - or not request.unique_key - or request.id in self._in_progress - or request.id in self._recently_handled - ): - continue + if not prolong_result: + return None - self._queue_head_dict[request.id] = request.id - self._cache_request( - cache_key=unique_key_to_request_id(request.unique_key), - processed_request=ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_handled=False, - was_already_present=True, - ), - ) + # 2.1.1. If successful, hydrate the request and return it + hydrated_request = await self.get_request(request_id) - # This is needed so that the next call to _ensureHeadIsNonEmpty() will fetch the queue head again. - self._query_queue_head_task = None + # Queue head index is ahead of the main table and the request is not present in the main table yet + # (i.e. getRequest() returned null). + if not hydrated_request: + # Remove the lock from the request for now, so that it can be picked up later + # This may/may not succeed, but that's fine + with suppress(Exception): + await self._resource_client.delete_request_lock(request_id) - return RequestQueueHeadState( - was_limit_reached=len(list_head.items) >= limit, - prev_limit=limit, - queue_modified_at=list_head.queue_modified_at, - query_started_at=query_started_at, - had_multiple_clients=list_head.had_multiple_clients, - ) + return None - def _maybe_add_request_to_queue_head( - self, - request_id: str, - *, - forefront: bool, - ) -> None: - if forefront: - self._queue_head_dict[request_id] = request_id - # Move to start, i.e. forefront of the queue - self._queue_head_dict.move_to_end(request_id, last=False) - elif self._assumed_total_count < self._QUERY_HEAD_MIN_LENGTH: - # OrderedDict puts the item to the end of the queue by default - self._queue_head_dict[request_id] = request_id + self._requests_cache[request_id] = { + 'id': request_id, + 'hydrated': hydrated_request, + 'was_already_handled': hydrated_request.handled_at is not None, + 'lock_expires_at': prolong_result, + } + + return hydrated_request + + # 1.1. If hydrated, prolong the lock more and return it + if cached_entry['hydrated']: + # 1.1.1. If the lock expired on the hydrated requests, try to prolong. If we fail, we lost the request + # (or it was handled already) + if cached_entry['lock_expires_at'] and cached_entry['lock_expires_at'] < datetime.now(timezone.utc): + prolonged = await self._prolong_request_lock(cached_entry['id']) + + if not prolonged: + return None + + cached_entry['lock_expires_at'] = prolonged + + return cached_entry['hydrated'] + + # 1.2. If not hydrated, try to prolong the lock first (to ensure we keep it in our queue), hydrate and return it + prolonged = await self._prolong_request_lock(cached_entry['id']) + + if not prolonged: + return None + + # This might still return null if the queue head is inconsistent with the main queue table. + hydrated_request = await self.get_request(cached_entry['id']) + + cached_entry['hydrated'] = hydrated_request + + # Queue head index is ahead of the main table and the request is not present in the main table yet + # (i.e. getRequest() returned null). 
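+        # (Same recovery as the uncached branch above: a lock that cannot be
+        # backed by a hydrated request is released, so a later head fetch can
+        # retry it once the write reaches the main table.)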
+ if not hydrated_request: + # Remove the lock from the request for now, so that it can be picked up later + # This may/may not succeed, but that's fine + with suppress(Exception): + await self._resource_client.delete_request_lock(cached_entry['id']) + + return None + + return hydrated_request + + async def _prolong_request_lock(self, request_id: str) -> datetime | None: + try: + res = await self._resource_client.prolong_request_lock( + request_id, lock_secs=int(self._request_lock_time.total_seconds()) + ) + except Exception as err: + # Most likely we do not own the lock anymore + logger.warning( + f'Failed to prolong lock for cached request {request_id}, either lost the lock ' + 'or the request was already handled\n', + exc_info=err, + ) + return None + else: + return res.lock_expires_at + + async def _clear_possible_locks(self) -> None: + self._queue_paused_for_migration = True + request_id: str | None = None + + while True: + try: + request_id, _ = self._queue_head_dict.popitem() + except KeyError: + break + + with suppress(Exception): + await self._resource_client.delete_request_lock(request_id) + # If this fails, we don't have the lock, or the request was never locked. Either way it's fine From 8917bd9cc2942f86a48faa9a35381993b8b5f1d5 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Wed, 14 Aug 2024 17:28:56 +0200 Subject: [PATCH 02/10] Unify service management --- src/crawlee/basic_crawler/basic_crawler.py | 13 ++- src/crawlee/configuration.py | 17 ++-- src/crawlee/service_container.py | 97 +++++++++++++++++++ src/crawlee/statistics/statistics.py | 4 +- src/crawlee/storage_client_manager.py | 50 ---------- src/crawlee/storages/_creation_management.py | 27 ++++-- src/crawlee/storages/dataset.py | 4 +- src/crawlee/storages/key_value_store.py | 3 + src/crawlee/storages/request_queue.py | 5 +- src/crawlee/types.py | 3 + tests/unit/conftest.py | 12 +-- .../test_memory_storage_e2e.py | 6 +- tests/unit/test_storage_client_manager.py | 33 ------- 13 files changed, 155 insertions(+), 119 deletions(-) create mode 100644 src/crawlee/service_container.py delete mode 100644 src/crawlee/storage_client_manager.py delete mode 100644 tests/unit/test_storage_client_manager.py diff --git a/src/crawlee/basic_crawler/basic_crawler.py b/src/crawlee/basic_crawler/basic_crawler.py index e12b8d9cf..53de93a76 100644 --- a/src/crawlee/basic_crawler/basic_crawler.py +++ b/src/crawlee/basic_crawler/basic_crawler.py @@ -18,6 +18,7 @@ from tldextract import TLDExtract from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, assert_never +import crawlee.service_container from crawlee import Glob from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -26,7 +27,6 @@ from crawlee.autoscaling.system_status import SystemStatus from crawlee.basic_crawler.context_pipeline import ContextPipeline from crawlee.basic_crawler.router import Router -from crawlee.configuration import Configuration from crawlee.enqueue_strategy import EnqueueStrategy from crawlee.errors import ( ContextPipelineInitializationError, @@ -35,7 +35,6 @@ SessionError, UserDefinedErrorHandlerError, ) -from crawlee.events import LocalEventManager from crawlee.http_clients import HttpxHttpClient from crawlee.log_config import CrawleeLogFormatter from crawlee.models import BaseRequestData, DatasetItemsListPage, Request, RequestState @@ -47,6 +46,8 @@ if TYPE_CHECKING: import re + from crawlee.configuration import Configuration + from crawlee.events.event_manager import EventManager 
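
These import changes introduce the dependency-injection convention used throughout this patch: each service becomes an optional constructor argument, and a process-wide service container supplies a lazily created default when the argument is omitted (see `event_manager or crawlee.service_container.get_event_manager()` below). A reduced sketch of the pattern, with a stub class standing in for crawlee's real `EventManager`:

    from __future__ import annotations

    class EventManager:
        """Stand-in for crawlee's EventManager."""

    _services: dict[str, EventManager] = {}

    def get_event_manager() -> EventManager:
        # Lazily create and register the default on first use.
        if 'event_manager' not in _services:
            _services['event_manager'] = EventManager()
        return _services['event_manager']

    class BasicCrawlerSketch:
        def __init__(self, event_manager: EventManager | None = None) -> None:
            # An explicitly injected service always wins over the container.
            self._event_manager = event_manager or get_event_manager()

    default_crawler = BasicCrawlerSketch()               # resolved from the container
    custom_crawler = BasicCrawlerSketch(EventManager())  # explicit injection
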
from crawlee.http_clients import BaseHttpClient, HttpResponse from crawlee.proxy_configuration import ProxyConfiguration, ProxyInfo from crawlee.sessions import Session @@ -77,6 +78,7 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]): retry_on_blocked: NotRequired[bool] proxy_configuration: NotRequired[ProxyConfiguration] statistics: NotRequired[Statistics[StatisticsState]] + event_manager: NotRequired[EventManager] configure_logging: NotRequired[bool] _context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]] _additional_context_managers: NotRequired[Sequence[AsyncContextManager]] @@ -111,6 +113,7 @@ def __init__( retry_on_blocked: bool = True, proxy_configuration: ProxyConfiguration | None = None, statistics: Statistics | None = None, + event_manager: EventManager | None = None, configure_logging: bool = True, _context_pipeline: ContextPipeline[TCrawlingContext] | None = None, _additional_context_managers: Sequence[AsyncContextManager] | None = None, @@ -138,6 +141,7 @@ def __init__( retry_on_blocked: If set to True, the crawler will try to automatically bypass any detected bot protection proxy_configuration: A HTTP proxy configuration to be used for making requests statistics: A preconfigured `Statistics` instance if you wish to use non-default configuration + event_manager: A custom `EventManager` instance if you wish to use a non-default one configure_logging: If set to True, the crawler will configure the logging infrastructure _context_pipeline: Allows extending the request lifecycle and modifying the crawling context. This parameter is meant to be used by child classes, not when BasicCrawler is instantiated directly. @@ -164,7 +168,7 @@ def __init__( self._max_session_rotations = max_session_rotations self._request_provider = request_provider - self._configuration = configuration or Configuration.get_global_configuration() + self._configuration = configuration or crawlee.service_container.get_configuration() self._request_handler_timeout = request_handler_timeout self._internal_timeout = ( @@ -175,8 +179,7 @@ def __init__( self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) - self._event_manager = LocalEventManager() # TODO: switch based on configuration - # https://github.com/apify/crawlee-py/issues/83 + self._event_manager = event_manager or crawlee.service_container.get_event_manager() self._snapshotter = Snapshotter(self._event_manager) self._pool = AutoscaledPool( system_status=SystemStatus(self._snapshotter), diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py index c81b744bc..6835f1a61 100644 --- a/src/crawlee/configuration.py +++ b/src/crawlee/configuration.py @@ -3,12 +3,13 @@ from __future__ import annotations from datetime import timedelta -from typing import Annotated, ClassVar, cast +from typing import Annotated from pydantic import AliasChoices, Field from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self +from crawlee import service_container from crawlee._utils.models import timedelta_ms @@ -22,8 +23,6 @@ class Configuration(BaseSettings): purge_on_start: Whether to purge the storage on start. 
""" - _default_instance: ClassVar[Self | None] = None - model_config = SettingsConfigDict(populate_by_name=True) internal_timeout: Annotated[timedelta | None, Field(alias='crawlee_internal_timeout')] = None @@ -206,12 +205,14 @@ class Configuration(BaseSettings): ), ] = False - in_cloud: Annotated[bool, Field(alias='crawlee_in_cloud')] = False - @classmethod def get_global_configuration(cls) -> Self: """Retrieve the global instance of the configuration.""" - if Configuration._default_instance is None: - Configuration._default_instance = cls() + global_instance = service_container.get_configuration() + + if not isinstance(global_instance, cls): + raise TypeError( + f'Requested global configuration object of type {cls}, but {global_instance.__class__} was found' + ) - return cast(Self, Configuration._default_instance) + return global_instance diff --git a/src/crawlee/service_container.py b/src/crawlee/service_container.py new file mode 100644 index 000000000..9f464d65a --- /dev/null +++ b/src/crawlee/service_container.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from typing_extensions import NotRequired, TypedDict + +from crawlee.configuration import Configuration +from crawlee.events.local_event_manager import LocalEventManager +from crawlee.memory_storage_client.memory_storage_client import MemoryStorageClient + +if TYPE_CHECKING: + from crawlee.base_storage_client.base_storage_client import BaseStorageClient + from crawlee.events.event_manager import EventManager + from crawlee.types import StorageClientType + + +class _Services(TypedDict): + local_storage_client: NotRequired[BaseStorageClient] + cloud_storage_client: NotRequired[BaseStorageClient] + configuration: NotRequired[Configuration] + event_manager: NotRequired[EventManager] + + +_services = _Services() +_default_storage_client_type: StorageClientType = 'local' + + +def get_storage_client(*, client_type: StorageClientType | None = None) -> BaseStorageClient: + """Get the storage client instance for the current environment. + + Args: + client_type: Allows retrieving a specific storage client type, regardless of where we are running + + Returns: + The current storage client instance. + """ + if client_type is None: + client_type = _default_storage_client_type + + if client_type == 'cloud': + if 'cloud_storage_client' not in _services: + raise RuntimeError('Cloud client was not provided.') + return _services['cloud_storage_client'] + + if 'local_storage_client' not in _services: + _services['local_storage_client'] = MemoryStorageClient() + + return _services['local_storage_client'] + + +def set_local_storage_client(local_client: BaseStorageClient) -> None: + """Set the local storage client instance. + + Args: + local_client: The local storage client instance. + """ + _services['local_storage_client'] = local_client + + +def set_cloud_storage_client(cloud_client: BaseStorageClient) -> None: + """Set the cloud storage client instance. + + Args: + cloud_client: The cloud storage client instance. 
+ """ + _services['cloud_storage_client'] = cloud_client + + +def set_default_storage_client_type(client_type: StorageClientType) -> None: + """Set the default storage client type.""" + _default_storage_client_type = client_type + + +def get_configuration() -> Configuration: + """Get the configuration object.""" + if 'configuration' not in _services: + _services['configuration'] = Configuration() + + return _services['configuration'] + + +def set_configuration(configuration: Configuration) -> None: + """Set the configuration object.""" + _services['configuration'] = configuration + + +def get_event_manager() -> EventManager: + """Get the event manager.""" + if 'event_manager' not in _services: + _services['event_manager'] = LocalEventManager() + + return _services['event_manager'] + + +def set_event_manager(event_manager: EventManager) -> None: + """Set the event manager.""" + _services['event_manager'] = event_manager diff --git a/src/crawlee/statistics/statistics.py b/src/crawlee/statistics/statistics.py index 49b8bda2f..f0ae7746a 100644 --- a/src/crawlee/statistics/statistics.py +++ b/src/crawlee/statistics/statistics.py @@ -8,8 +8,8 @@ from typing_extensions import Self, TypeVar +import crawlee.service_container from crawlee._utils.recurring_task import RecurringTask -from crawlee.events import LocalEventManager from crawlee.events.types import Event, EventPersistStateData from crawlee.statistics import FinalStatistics, StatisticsPersistedState, StatisticsState from crawlee.statistics.error_tracker import ErrorTracker @@ -85,7 +85,7 @@ def __init__( self.error_tracker = ErrorTracker() self.error_tracker_retry = ErrorTracker() - self._events = event_manager or LocalEventManager() + self._events = event_manager or crawlee.service_container.get_event_manager() self._requests_in_progress = dict[str, RequestProcessingRecord]() diff --git a/src/crawlee/storage_client_manager.py b/src/crawlee/storage_client_manager.py deleted file mode 100644 index dc7e3a892..000000000 --- a/src/crawlee/storage_client_manager.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from crawlee.memory_storage_client import MemoryStorageClient - -if TYPE_CHECKING: - from crawlee.base_storage_client import BaseStorageClient - - -class StorageClientManager: - """A class for managing storage clients.""" - - _local_client: BaseStorageClient = MemoryStorageClient() - _cloud_client: BaseStorageClient | None = None - - @classmethod - def get_storage_client(cls, *, in_cloud: bool = False) -> BaseStorageClient: - """Get the storage client instance for the current environment. - - Args: - in_cloud: Whether the code is running in the cloud environment. - - Returns: - The current storage client instance. - """ - if in_cloud: - if cls._cloud_client is None: - raise RuntimeError('Running in cloud environment, but cloud client was not provided.') - return cls._cloud_client - - return cls._local_client - - @classmethod - def set_cloud_client(cls, cloud_client: BaseStorageClient) -> None: - """Set the cloud storage client instance. - - Args: - cloud_client: The cloud storage client instance. - """ - cls._cloud_client = cloud_client - - @classmethod - def set_local_client(cls, local_client: BaseStorageClient) -> None: - """Set the local storage client instance. - - Args: - local_client: The local storage client instance. 
- """ - cls._local_client = local_client diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py index 6a4aa0deb..f6a636afb 100644 --- a/src/crawlee/storages/_creation_management.py +++ b/src/crawlee/storages/_creation_management.py @@ -3,14 +3,15 @@ import asyncio from typing import TYPE_CHECKING, TypeVar +from crawlee import service_container from crawlee.configuration import Configuration from crawlee.memory_storage_client import MemoryStorageClient -from crawlee.storage_client_manager import StorageClientManager from crawlee.storages import Dataset, KeyValueStore, RequestQueue if TYPE_CHECKING: from crawlee.base_storage_client import BaseStorageClient from crawlee.base_storage_client.types import ResourceClient, ResourceCollectionClient + from crawlee.types import StorageClientType TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) @@ -122,13 +123,14 @@ def _get_default_storage_id(configuration: Configuration, storage_class: type[TR async def open_storage( *, storage_class: type[TResource], + storage_client_type: StorageClientType | None = None, configuration: Configuration | None = None, id: str | None = None, name: str | None = None, ) -> TResource: """Open either a new storage or restore an existing one and return it.""" configuration = configuration or Configuration.get_global_configuration() - storage_client = StorageClientManager.get_storage_client(in_cloud=configuration.in_cloud) + storage_client = service_container.get_storage_client(client_type=storage_client_type) # Try to restore the storage from cache by name if name: @@ -170,12 +172,21 @@ async def open_storage( resource_collection_client = _get_resource_collection_client(storage_class, storage_client) storage_info = await resource_collection_client.get_or_create(name=name) - storage = storage_class( - id=storage_info.id, - name=storage_info.name, - configuration=configuration, - client=storage_client, - ) + if issubclass(storage_class, RequestQueue): + storage = storage_class( + id=storage_info.id, + name=storage_info.name, + configuration=configuration, + client=storage_client, + event_manager=service_container.get_event_manager(), + ) + else: + storage = storage_class( + id=storage_info.id, + name=storage_info.name, + configuration=configuration, + client=storage_client, + ) # Cache the storage by ID and name _add_to_cache_by_id(storage.id, storage) diff --git a/src/crawlee/storages/dataset.py b/src/crawlee/storages/dataset.py index b72ee1303..df3875261 100644 --- a/src/crawlee/storages/dataset.py +++ b/src/crawlee/storages/dataset.py @@ -18,7 +18,7 @@ from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration from crawlee.models import DatasetItemsListPage - from crawlee.types import JSONSerializable + from crawlee.types import JSONSerializable, StorageClientType logger = logging.getLogger(__name__) @@ -135,6 +135,7 @@ async def open( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, + storage_client_type: StorageClientType | None = None, ) -> Dataset: from crawlee.storages._creation_management import open_storage @@ -143,6 +144,7 @@ async def open( id=id, name=name, configuration=configuration, + storage_client_type=storage_client_type, ) @override diff --git a/src/crawlee/storages/key_value_store.py b/src/crawlee/storages/key_value_store.py index b63706e31..f9035a388 100644 --- a/src/crawlee/storages/key_value_store.py +++ 
b/src/crawlee/storages/key_value_store.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration + from crawlee.types import StorageClientType T = TypeVar('T') @@ -71,6 +72,7 @@ async def open( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, + storage_client_type: StorageClientType | None = None, ) -> KeyValueStore: from crawlee.storages._creation_management import open_storage @@ -79,6 +81,7 @@ async def open( id=id, name=name, configuration=configuration, + storage_client_type=storage_client_type, ) @override diff --git a/src/crawlee/storages/request_queue.py b/src/crawlee/storages/request_queue.py index 99cb8c112..192c33b7b 100644 --- a/src/crawlee/storages/request_queue.py +++ b/src/crawlee/storages/request_queue.py @@ -13,7 +13,6 @@ from crawlee._utils.lru_cache import LRUCache from crawlee._utils.requests import unique_key_to_request_id from crawlee._utils.wait import wait_for_all_tasks_for_finish -from crawlee.events.event_manager import EventManager from crawlee.events.types import Event from crawlee.models import ( BaseRequestData, @@ -29,6 +28,8 @@ from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration + from crawlee.events.event_manager import EventManager + from crawlee.types import StorageClientType __all__ = ['RequestQueue'] @@ -151,6 +152,7 @@ async def open( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, + storage_client_type: StorageClientType | None = None, ) -> RequestQueue: from crawlee.storages._creation_management import open_storage @@ -159,6 +161,7 @@ async def open( id=id, name=name, configuration=configuration, + storage_client_type=storage_client_type, ) await storage._ensure_head_is_non_empty() # noqa: SLF001 - accessing private members from factories is OK diff --git a/src/crawlee/types.py b/src/crawlee/types.py index 004ef047c..8d3ed9c82 100644 --- a/src/crawlee/types.py +++ b/src/crawlee/types.py @@ -35,6 +35,9 @@ class StorageTypes(str, Enum): REQUEST_QUEUE = 'Request queue' +StorageClientType = Literal['cloud', 'local'] + + class AddRequestsKwargs(TypedDict): """Keyword arguments for crawler's `add_requests` method.""" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 73e996005..c7d456db3 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -4,15 +4,15 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, cast import pytest from proxy import Proxy +from crawlee import service_container from crawlee.configuration import Configuration from crawlee.memory_storage_client import MemoryStorageClient from crawlee.proxy_configuration import ProxyInfo -from crawlee.storage_client_manager import StorageClientManager from crawlee.storages import _creation_management if TYPE_CHECKING: @@ -26,12 +26,8 @@ def reset() -> None: # Set the environment variable for the local storage directory to the temporary path monkeypatch.setenv('CRAWLEE_STORAGE_DIR', str(tmp_path)) - # Reset the local and cloud clients in StorageClientManager - StorageClientManager._local_client = MemoryStorageClient() - StorageClientManager._cloud_client = None - - # Remove global configuration instance - it may contain settings adjusted by a previous test - Configuration._default_instance = None + # Reset services in crawlee.service_container + cast(dict, 
service_container._services).clear() # Clear creation-related caches to ensure no state is carried over between tests monkeypatch.setattr(_creation_management, '_cache_dataset_by_id', {}) diff --git a/tests/unit/memory_storage_client/test_memory_storage_e2e.py b/tests/unit/memory_storage_client/test_memory_storage_e2e.py index 19e71b594..3779e8ebf 100644 --- a/tests/unit/memory_storage_client/test_memory_storage_e2e.py +++ b/tests/unit/memory_storage_client/test_memory_storage_e2e.py @@ -5,8 +5,8 @@ import pytest +from crawlee import service_container from crawlee.models import Request -from crawlee.storage_client_manager import StorageClientManager from crawlee.storages.key_value_store import KeyValueStore from crawlee.storages.request_queue import RequestQueue @@ -23,7 +23,7 @@ async def test_actor_memory_storage_client_key_value_store_e2e( # Configure purging env var monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}') # Store old storage client so we have the object reference for comparison - old_client = StorageClientManager.get_storage_client() + old_client = service_container.get_storage_client() old_default_kvs = await KeyValueStore.open() old_non_default_kvs = await KeyValueStore.open(name='non-default') @@ -36,7 +36,7 @@ async def test_actor_memory_storage_client_key_value_store_e2e( reset_globals() # Check if we're using a different memory storage instance - assert old_client is not StorageClientManager.get_storage_client() + assert old_client is not service_container.get_storage_client() default_kvs = await KeyValueStore.open() assert default_kvs is not old_default_kvs non_default_kvs = await KeyValueStore.open(name='non-default') diff --git a/tests/unit/test_storage_client_manager.py b/tests/unit/test_storage_client_manager.py deleted file mode 100644 index b4fdd279d..000000000 --- a/tests/unit/test_storage_client_manager.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from unittest.mock import Mock - -import pytest - -from crawlee.base_storage_client import BaseStorageClient -from crawlee.memory_storage_client import MemoryStorageClient -from crawlee.storage_client_manager import StorageClientManager - - -def test_returns_memory_storage_client_as_default() -> None: - storage_client = StorageClientManager.get_storage_client() - assert isinstance(storage_client, MemoryStorageClient), 'Should return the memory storage client by default' - - -def test_returns_provided_local_client_for_non_cloud_environment() -> None: - local_client = Mock(spec=BaseStorageClient) - StorageClientManager.set_local_client(local_client) - storage_client = StorageClientManager.get_storage_client() - assert storage_client == local_client, 'Should return the local client when not in cloud' - - -def test_returns_provided_cloud_client_for_cloud_environment() -> None: - cloud_client = Mock(spec=BaseStorageClient) - StorageClientManager.set_cloud_client(cloud_client) - storage_client = StorageClientManager.get_storage_client(in_cloud=True) - assert storage_client == cloud_client, 'Should return the cloud client when in cloud' - - -def test_raises_error_when_no_cloud_client_provided() -> None: - with pytest.raises(RuntimeError, match='cloud client was not provided'): - StorageClientManager.get_storage_client(in_cloud=True) From 0284591490e336230aa5c26672ac92ca213f695e Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 15 Aug 2024 09:34:51 +0200 Subject: [PATCH 03/10] Placeholder implementation of locking --- .../base_request_queue_client.py | 15 ------------ 
.../request_queue_client.py | 23 ++++++++----------- src/crawlee/models.py | 11 --------- 3 files changed, 10 insertions(+), 39 deletions(-) diff --git a/src/crawlee/base_storage_client/base_request_queue_client.py b/src/crawlee/base_storage_client/base_request_queue_client.py index 309b65e3a..494096b16 100644 --- a/src/crawlee/base_storage_client/base_request_queue_client.py +++ b/src/crawlee/base_storage_client/base_request_queue_client.py @@ -11,7 +11,6 @@ ProcessedRequest, ProlongRequestLockResponse, Request, - RequestListResponse, RequestQueueHead, RequestQueueHeadWithLocks, RequestQueueMetadata, @@ -185,17 +184,3 @@ async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsO Args: requests: The requests to delete from the queue. """ - - @abstractmethod - async def list_requests( - self, - *, - limit: int | None = None, - exclusive_start_id: str | None = None, - ) -> RequestListResponse: - """List requests from the queue. - - Args: - limit: How many requests to retrieve. - exclusive_start_id: All requests up to this one (including) are skipped from the result. - """ diff --git a/src/crawlee/memory_storage_client/request_queue_client.py b/src/crawlee/memory_storage_client/request_queue_client.py index 0b8817777..1acbfbb7b 100644 --- a/src/crawlee/memory_storage_client/request_queue_client.py +++ b/src/crawlee/memory_storage_client/request_queue_client.py @@ -30,7 +30,6 @@ ProcessedRequest, ProlongRequestLockResponse, Request, - RequestListResponse, RequestQueueHead, RequestQueueHeadWithLocks, RequestQueueMetadata, @@ -215,7 +214,14 @@ async def list_head(self, *, limit: int | None = None) -> RequestQueueHead: @override async def list_and_lock_head(self, *, lock_secs: int, limit: int | None = None) -> RequestQueueHeadWithLocks: - raise NotImplementedError('This method is not supported in memory storage.') + result = await self.list_head(limit=limit) + return RequestQueueHeadWithLocks( + lock_secs=lock_secs, + limit=result.limit, + had_multiple_clients=result.had_multiple_clients, + queue_modified_at=result.queue_modified_at, + items=result.items, + ) @override async def add_request( @@ -380,7 +386,7 @@ async def prolong_request_lock( forefront: bool = False, lock_secs: int, ) -> ProlongRequestLockResponse: - raise NotImplementedError('This method is not supported in memory storage.') + return ProlongRequestLockResponse(lock_expires_at=datetime.now(timezone.utc)) @override async def delete_request_lock( @@ -389,7 +395,7 @@ async def delete_request_lock( *, forefront: bool = False, ) -> None: - raise NotImplementedError('This method is not supported in memory storage.') + return None @override async def batch_add_requests( @@ -431,15 +437,6 @@ async def batch_add_requests( async def batch_delete_requests(self, requests: list[Request]) -> BatchRequestsOperationResponse: raise NotImplementedError('This method is not supported in memory storage.') - @override - async def list_requests( - self, - *, - limit: int | None = None, - exclusive_start_id: str | None = None, - ) -> RequestListResponse: - raise NotImplementedError('This method is not supported in memory storage.') - async def update_timestamps(self, *, has_been_modified: bool) -> None: """Update the timestamps of the request queue.""" self._accessed_at = datetime.now(timezone.utc) diff --git a/src/crawlee/models.py b/src/crawlee/models.py index 4528a5aee..1fccd8338 100644 --- a/src/crawlee/models.py +++ b/src/crawlee/models.py @@ -350,7 +350,6 @@ class RequestQueueHeadWithLocks(RequestQueueHead): """Model for 
request queue head with locks.""" lock_secs: Annotated[int, Field(alias='lockSecs')] - items: Annotated[list[Request], Field(alias='items', default_factory=list)] class BaseListPage(BaseModel): @@ -449,13 +448,3 @@ class BatchRequestsOperationResponse(BaseModel): processed_requests: Annotated[list[ProcessedRequest], Field(alias='processedRequests')] unprocessed_requests: Annotated[list[UnprocessedRequest], Field(alias='unprocessedRequests')] - - -class RequestListResponse(BaseModel): - """Response to a request list call.""" - - model_config = ConfigDict(populate_by_name=True) - - limit: Annotated[int, Field()] - exclusive_start_key: Annotated[str | None, Field(alias='exclusiveStartId')] - items: Annotated[list[Request], Field()] From ad3fa87f6188b5a7e5f1e5148d62d95a4939fe28 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 15 Aug 2024 09:42:43 +0200 Subject: [PATCH 04/10] Fix request dequeueing order --- src/crawlee/storages/request_queue.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/crawlee/storages/request_queue.py b/src/crawlee/storages/request_queue.py index 192c33b7b..56ee38cfa 100644 --- a/src/crawlee/storages/request_queue.py +++ b/src/crawlee/storages/request_queue.py @@ -156,7 +156,7 @@ async def open( ) -> RequestQueue: from crawlee.storages._creation_management import open_storage - storage = await open_storage( + return await open_storage( storage_class=cls, id=id, name=name, @@ -164,10 +164,6 @@ async def open( storage_client_type=storage_client_type, ) - await storage._ensure_head_is_non_empty() # noqa: SLF001 - accessing private members from factories is OK - - return storage - @override async def drop(self, *, timeout: timedelta | None = None) -> None: from crawlee.storages._creation_management import remove_storage_from_cache From 825a1abfb3f0fd8c36131f577a83463e70b52879 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 15 Aug 2024 13:56:10 +0200 Subject: [PATCH 05/10] Stuff --- src/crawlee/configuration.py | 6 +++++- src/crawlee/service_container.py | 11 +++++++++-- src/crawlee/storages/_creation_management.py | 5 ++--- src/crawlee/storages/dataset.py | 6 +++--- src/crawlee/storages/key_value_store.py | 5 ++--- src/crawlee/storages/request_queue.py | 5 ++--- src/crawlee/types.py | 3 --- 7 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py index 6835f1a61..ad6ea0204 100644 --- a/src/crawlee/configuration.py +++ b/src/crawlee/configuration.py @@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self -from crawlee import service_container from crawlee._utils.models import timedelta_ms @@ -208,6 +207,11 @@ class Configuration(BaseSettings): @classmethod def get_global_configuration(cls) -> Self: """Retrieve the global instance of the configuration.""" + from crawlee import service_container + + if service_container.get_configuration_if_set() is None: + service_container.set_configuration(cls()) + global_instance = service_container.get_configuration() if not isinstance(global_instance, cls): diff --git a/src/crawlee/service_container.py b/src/crawlee/service_container.py index 9f464d65a..b039fab4a 100644 --- a/src/crawlee/service_container.py +++ b/src/crawlee/service_container.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from typing_extensions import NotRequired, TypedDict @@ -11,7 +11,9 @@ if TYPE_CHECKING: from 
crawlee.base_storage_client.base_storage_client import BaseStorageClient from crawlee.events.event_manager import EventManager - from crawlee.types import StorageClientType + + +StorageClientType = Literal['cloud', 'local'] class _Services(TypedDict): @@ -79,6 +81,11 @@ def get_configuration() -> Configuration: return _services['configuration'] +def get_configuration_if_set() -> Configuration | None: + """Get the configuration object, or None if it hasn't been set yet.""" + return _services.get('configuration') + + def set_configuration(configuration: Configuration) -> None: """Set the configuration object.""" _services['configuration'] = configuration diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py index f6a636afb..b50932286 100644 --- a/src/crawlee/storages/_creation_management.py +++ b/src/crawlee/storages/_creation_management.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: from crawlee.base_storage_client import BaseStorageClient from crawlee.base_storage_client.types import ResourceClient, ResourceCollectionClient - from crawlee.types import StorageClientType TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue) @@ -123,14 +122,14 @@ def _get_default_storage_id(configuration: Configuration, storage_class: type[TR async def open_storage( *, storage_class: type[TResource], - storage_client_type: StorageClientType | None = None, + storage_client: BaseStorageClient | None = None, configuration: Configuration | None = None, id: str | None = None, name: str | None = None, ) -> TResource: """Open either a new storage or restore an existing one and return it.""" configuration = configuration or Configuration.get_global_configuration() - storage_client = service_container.get_storage_client(client_type=storage_client_type) + storage_client = storage_client or service_container.get_storage_client() # Try to restore the storage from cache by name if name: diff --git a/src/crawlee/storages/dataset.py b/src/crawlee/storages/dataset.py index df3875261..fe3a045f6 100644 --- a/src/crawlee/storages/dataset.py +++ b/src/crawlee/storages/dataset.py @@ -18,7 +18,7 @@ from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration from crawlee.models import DatasetItemsListPage - from crawlee.types import JSONSerializable, StorageClientType + from crawlee.types import JSONSerializable logger = logging.getLogger(__name__) @@ -135,7 +135,7 @@ async def open( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, - storage_client_type: StorageClientType | None = None, + storage_client: BaseStorageClient | None = None, ) -> Dataset: from crawlee.storages._creation_management import open_storage @@ -144,7 +144,7 @@ async def open( id=id, name=name, configuration=configuration, - storage_client_type=storage_client_type, + storage_client=storage_client, ) @override diff --git a/src/crawlee/storages/key_value_store.py b/src/crawlee/storages/key_value_store.py index f9035a388..afdfb5729 100644 --- a/src/crawlee/storages/key_value_store.py +++ b/src/crawlee/storages/key_value_store.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration - from crawlee.types import StorageClientType T = TypeVar('T') @@ -72,7 +71,7 @@ async def open( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, - storage_client_type: StorageClientType | None = None, + 
storage_client: BaseStorageClient | None = None, ) -> KeyValueStore: from crawlee.storages._creation_management import open_storage @@ -81,7 +80,7 @@ async def open( id=id, name=name, configuration=configuration, - storage_client_type=storage_client_type, + storage_client=storage_client, ) @override diff --git a/src/crawlee/storages/request_queue.py b/src/crawlee/storages/request_queue.py index 56ee38cfa..38f5431fa 100644 --- a/src/crawlee/storages/request_queue.py +++ b/src/crawlee/storages/request_queue.py @@ -29,7 +29,6 @@ from crawlee.base_storage_client import BaseStorageClient from crawlee.configuration import Configuration from crawlee.events.event_manager import EventManager - from crawlee.types import StorageClientType __all__ = ['RequestQueue'] @@ -152,7 +151,7 @@ async def open( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, - storage_client_type: StorageClientType | None = None, + storage_client: BaseStorageClient | None = None, ) -> RequestQueue: from crawlee.storages._creation_management import open_storage @@ -161,7 +160,7 @@ async def open( id=id, name=name, configuration=configuration, - storage_client_type=storage_client_type, + storage_client=storage_client, ) @override diff --git a/src/crawlee/types.py b/src/crawlee/types.py index 8d3ed9c82..004ef047c 100644 --- a/src/crawlee/types.py +++ b/src/crawlee/types.py @@ -35,9 +35,6 @@ class StorageTypes(str, Enum): REQUEST_QUEUE = 'Request queue' -StorageClientType = Literal['cloud', 'local'] - - class AddRequestsKwargs(TypedDict): """Keyword arguments for crawler's `add_requests` method.""" From 451f6a2feab986f77a52299f816c082f830266be Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 15 Aug 2024 15:16:19 +0200 Subject: [PATCH 06/10] Throw when services are overwritten --- src/crawlee/service_container.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/crawlee/service_container.py b/src/crawlee/service_container.py index b039fab4a..de07f8063 100644 --- a/src/crawlee/service_container.py +++ b/src/crawlee/service_container.py @@ -27,6 +27,15 @@ class _Services(TypedDict): _default_storage_client_type: StorageClientType = 'local' +class ServiceConflictError(RuntimeError): + """Thrown when a service is getting reconfigured.""" + + def __init__(self, service_name: str, new_value: object, old_value: object) -> None: + super().__init__( + f"Service '{service_name}' was already set (existing value is '{old_value}', new value is '{new_value}')." + ) + + def get_storage_client(*, client_type: StorageClientType | None = None) -> BaseStorageClient: """Get the storage client instance for the current environment. @@ -56,6 +65,9 @@ def set_local_storage_client(local_client: BaseStorageClient) -> None: Args: local_client: The local storage client instance. """ + if (existing_service := _services.get('local_storage_client')) and existing_service is not local_client: + raise ServiceConflictError('local_storage_client', local_client, existing_service) + _services['local_storage_client'] = local_client @@ -65,6 +77,9 @@ def set_cloud_storage_client(cloud_client: BaseStorageClient) -> None: Args: cloud_client: The cloud storage client instance. 
""" + if (existing_service := _services.get('cloud_storage_client')) and existing_service is not cloud_client: + raise ServiceConflictError('cloud_storage_client', cloud_client, existing_service) + _services['cloud_storage_client'] = cloud_client @@ -88,6 +103,9 @@ def get_configuration_if_set() -> Configuration | None: def set_configuration(configuration: Configuration) -> None: """Set the configuration object.""" + if (existing_service := _services.get('configuration')) and existing_service is not configuration: + raise ServiceConflictError('configuration', configuration, existing_service) + _services['configuration'] = configuration @@ -101,4 +119,7 @@ def get_event_manager() -> EventManager: def set_event_manager(event_manager: EventManager) -> None: """Set the event manager.""" + if (existing_service := _services.get('event_manager')) and existing_service is not event_manager: + raise ServiceConflictError('event_manager', event_manager, existing_service) + _services['event_manager'] = event_manager From 5cefd2d3864b3d79a25bb7cbb74f8db675276f58 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 22 Aug 2024 10:42:45 +0200 Subject: [PATCH 07/10] Update src/crawlee/service_container.py Co-authored-by: Vlada Dusek --- src/crawlee/service_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crawlee/service_container.py b/src/crawlee/service_container.py index de07f8063..ce2e236f2 100644 --- a/src/crawlee/service_container.py +++ b/src/crawlee/service_container.py @@ -40,7 +40,7 @@ def get_storage_client(*, client_type: StorageClientType | None = None) -> BaseS """Get the storage client instance for the current environment. Args: - client_type: Allows retrieving a specific storage client type, regardless of where we are running + client_type: Allows retrieving a specific storage client type, regardless of where we are running. Returns: The current storage client instance. From 3766f97919daef258ec8fc40ba3aeefdf8676fa5 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 22 Aug 2024 10:42:55 +0200 Subject: [PATCH 08/10] Update src/crawlee/storages/request_queue.py Co-authored-by: Vlada Dusek --- src/crawlee/storages/request_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crawlee/storages/request_queue.py b/src/crawlee/storages/request_queue.py index 38f5431fa..8abdc7d2f 100644 --- a/src/crawlee/storages/request_queue.py +++ b/src/crawlee/storages/request_queue.py @@ -204,7 +204,7 @@ async def add_request( use_extended_unique_key: Determines whether to use an extended unique key, incorporating the request's method and payload into the unique key computation. - Returns: Information about the processed request + Returns: Information about the processed request. 
""" request = self._transform_request(request) self._last_activity = datetime.now(timezone.utc) From 10661314e97147f5580a73ea1d350ef2099c7006 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 22 Aug 2024 10:47:23 +0200 Subject: [PATCH 09/10] Change import --- src/crawlee/basic_crawler/basic_crawler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/crawlee/basic_crawler/basic_crawler.py b/src/crawlee/basic_crawler/basic_crawler.py index 53de93a76..8beda470d 100644 --- a/src/crawlee/basic_crawler/basic_crawler.py +++ b/src/crawlee/basic_crawler/basic_crawler.py @@ -18,8 +18,7 @@ from tldextract import TLDExtract from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, assert_never -import crawlee.service_container -from crawlee import Glob +from crawlee import Glob, service_container from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for from crawlee.autoscaling import AutoscaledPool, ConcurrencySettings @@ -168,7 +167,7 @@ def __init__( self._max_session_rotations = max_session_rotations self._request_provider = request_provider - self._configuration = configuration or crawlee.service_container.get_configuration() + self._configuration = configuration or service_container.get_configuration() self._request_handler_timeout = request_handler_timeout self._internal_timeout = ( @@ -179,7 +178,7 @@ def __init__( self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) - self._event_manager = event_manager or crawlee.service_container.get_event_manager() + self._event_manager = event_manager or service_container.get_event_manager() self._snapshotter = Snapshotter(self._event_manager) self._pool = AutoscaledPool( system_status=SystemStatus(self._snapshotter), From bfac67921b937451c91205e5ab19c62462626350 Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Thu, 22 Aug 2024 11:42:31 +0200 Subject: [PATCH 10/10] Add service container tests --- src/crawlee/service_container.py | 1 + tests/unit/test_service_container.py | 89 ++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tests/unit/test_service_container.py diff --git a/src/crawlee/service_container.py b/src/crawlee/service_container.py index ce2e236f2..3eaee5f45 100644 --- a/src/crawlee/service_container.py +++ b/src/crawlee/service_container.py @@ -85,6 +85,7 @@ def set_cloud_storage_client(cloud_client: BaseStorageClient) -> None: def set_default_storage_client_type(client_type: StorageClientType) -> None: """Set the default storage client type.""" + global _default_storage_client_type # noqa: PLW0603 _default_storage_client_type = client_type diff --git a/tests/unit/test_service_container.py b/tests/unit/test_service_container.py new file mode 100644 index 000000000..1646dedf9 --- /dev/null +++ b/tests/unit/test_service_container.py @@ -0,0 +1,89 @@ +from unittest.mock import Mock + +import pytest + +from crawlee import service_container +from crawlee.configuration import Configuration +from crawlee.events.local_event_manager import LocalEventManager +from crawlee.memory_storage_client.memory_storage_client import MemoryStorageClient + + +async def test_get_event_manager() -> None: + event_manager = service_container.get_event_manager() + assert isinstance(event_manager, LocalEventManager) + + +async def test_set_event_manager() -> None: + event_manager = Mock() + service_container.set_event_manager(event_manager) + assert service_container.get_event_manager() is event_manager + + +async def 
test_overwrite_event_manager() -> None: + event_manager = Mock() + service_container.set_event_manager(event_manager) + service_container.set_event_manager(event_manager) + + with pytest.raises(service_container.ServiceConflictError): + service_container.set_event_manager(Mock()) + + +async def test_get_configuration() -> None: + configuration = service_container.get_configuration() + assert isinstance(configuration, Configuration) + + +async def test_set_configuration() -> None: + configuration = Mock() + service_container.set_configuration(configuration) + assert service_container.get_configuration() is configuration + + +async def test_overwrite_configuration() -> None: + configuration = Mock() + service_container.set_configuration(configuration) + service_container.set_configuration(configuration) + + with pytest.raises(service_container.ServiceConflictError): + service_container.set_configuration(Mock()) + + +async def test_get_storage_client() -> None: + storage_client = service_container.get_storage_client() + assert isinstance(storage_client, MemoryStorageClient) + + with pytest.raises(RuntimeError): + service_container.get_storage_client(client_type='cloud') + + service_container.set_default_storage_client_type('cloud') + + with pytest.raises(RuntimeError): + service_container.get_storage_client() + + storage_client = service_container.get_storage_client(client_type='local') + assert isinstance(storage_client, MemoryStorageClient) + + cloud_client = Mock() + service_container.set_cloud_storage_client(cloud_client) + assert service_container.get_storage_client(client_type='cloud') is cloud_client + assert service_container.get_storage_client() is cloud_client + + +async def test_reset_local_storage_client() -> None: + storage_client = Mock() + + service_container.set_local_storage_client(storage_client) + service_container.set_local_storage_client(storage_client) + + with pytest.raises(service_container.ServiceConflictError): + service_container.set_local_storage_client(Mock()) + + +async def test_reset_cloud_storage_client() -> None: + storage_client = Mock() + + service_container.set_cloud_storage_client(storage_client) + service_container.set_cloud_storage_client(storage_client) + + with pytest.raises(service_container.ServiceConflictError): + service_container.set_cloud_storage_client(Mock())
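
A note on the one-line "global _default_storage_client_type  # noqa: PLW0603" fix
in PATCH 10/10: without the global statement, the assignment inside
set_default_storage_client_type() binds a new function-local name and the
module-level default silently never changes, so test_get_storage_client() above
could not pass. A minimal, self-contained sketch of the pitfall, using
illustrative names only rather than the actual crawlee module:

    _default_client_type = 'local'  # module-level default


    def set_default_broken(client_type: str) -> None:
        # Missing `global`: this assignment creates a function-local variable
        # that is discarded on return; the module-level default is untouched.
        _default_client_type = client_type


    def set_default_fixed(client_type: str) -> None:
        global _default_client_type  # rebind the module-level name (ruff rule PLW0603)
        _default_client_type = client_type


    set_default_broken('cloud')
    assert _default_client_type == 'local'  # the broken setter had no effect

    set_default_fixed('cloud')
    assert _default_client_type == 'cloud'  # the fixed setter updates the default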