diff --git a/elm/version.py b/elm/version.py
index 806a233f..f9ff74d6 100644
--- a/elm/version.py
+++ b/elm/version.py
@@ -2,4 +2,4 @@
 ELM version number
 """
 
-__version__ = "0.0.23"
+__version__ = "0.0.24"
diff --git a/elm/web/search/google.py b/elm/web/search/google.py
index 57ef7689..77c99e61 100644
--- a/elm/web/search/google.py
+++ b/elm/web/search/google.py
@@ -9,7 +9,7 @@
 from contextlib import asynccontextmanager
 
 from camoufox.async_api import AsyncCamoufox
-from apiclient.discovery import build
+from googleapiclient.discovery import build
 from playwright.async_api import TimeoutError as PlaywrightTimeoutError
 
 from elm.web.search.base import (PlaywrightSearchEngineLinkSearch,
@@ -244,6 +244,22 @@ class APISerperSearch(APISearchEngineLinkSearch):
     API_KEY_VAR = "SERPER_API_KEY"
     """Environment variable that should contain the Google Serper API key"""
 
+    def __init__(self, api_key=None, verify=False):
+        """
+
+        Parameters
+        ----------
+        api_key : str, optional
+            API key for serper search API. If ``None``, will look up the
+            API key using the ``"SERPER_API_KEY"`` environment variable.
+            By default, ``None``.
+        verify : bool, default=False
+            Option to use SSL verification when making request to API
+            endpoint. By default, ``False``.
+        """
+        super().__init__(api_key=api_key)
+        self.verify = verify
+
     async def _search(self, query, num_results=10):
         """Search web for links related to a query"""
 
@@ -252,7 +268,7 @@ async def _search(self, query, num_results=10):
                    'Content-Type': 'application/json'}
 
         response = requests.request("POST", self._URL, headers=headers,
-                                    data=payload)
+                                    data=payload, verify=self.verify)
         results = json.loads(response.text).get('organic', {})
         return list(filter(None, (result.get("link", "").replace("+", "%20")
                                   for result in results)))
diff --git a/elm/web/search/run.py b/elm/web/search/run.py
index 438c4d4e..1fc95cef 100644
--- a/elm/web/search/run.py
+++ b/elm/web/search/run.py
@@ -56,6 +56,7 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
                                    num_urls=None, ignore_url_parts=None,
                                    search_semaphore=None,
                                    browser_semaphore=None, task_name=None,
+                                   use_fallback_per_query=True,
                                    on_search_complete_hook=None, **kwargs):
     """Retrieve top ``N`` search results as document instances
 
@@ -75,7 +76,9 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
         is used and so on. If all web searches fail, an empty list is
         returned. See :obj:`~elm.web.search.run.SEARCH_ENGINE_OPTIONS`
         for supported search engine options.
-        By default, ``("PlaywrightGoogleLinkSearch", )``.
+        By default, ``("PlaywrightGoogleLinkSearch",
+        "PlaywrightDuckDuckGoLinkSearch",
+        "DuxDistributedGlobalSearch")``.
     num_urls : int, optional
         Number of unique top Google search result to return as docs.
         The google search results from all queries are interleaved and the
@@ -101,6 +104,13 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
     task_name : str, optional
         Optional task name to use in :func:`asyncio.create_task`.
        By default, ``None``.
+    use_fallback_per_query : bool, default=True
+        Option to use the fallback list of search engines on a per-query
+        basis. This means if a single query fails with one search
+        engine, the fallback search engines will be attempted for that
+        query. If this input is ``False``, the fallback search engines
+        are only used if *all* search queries fail for a single search
+        engine. By default, ``True``.
     on_search_complete_hook : callable, optional
         If provided, this async callable will be called after the search
         engine links have been retrieved. A single argument will be
@@ -145,11 +155,13 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
         # backward-compatibility
         search_semaphore = browser_semaphore
 
+    fpq = use_fallback_per_query
     urls = await search_with_fallback(queries, search_engines=search_engines,
                                       num_urls=num_urls,
                                       ignore_url_parts=ignore_url_parts,
                                       browser_semaphore=search_semaphore,
-                                      task_name=task_name, **kwargs)
+                                      task_name=task_name,
+                                      use_fallback_per_query=fpq, **kwargs)
 
     if on_search_complete_hook is not None:
         await on_search_complete_hook(urls)
@@ -161,7 +173,7 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
 async def search_with_fallback(queries, search_engines=_DEFAULT_SE,
                                num_urls=None, ignore_url_parts=None,
                                browser_semaphore=None, task_name=None,
-                               **kwargs):
+                               use_fallback_per_query=True, **kwargs):
     """Retrieve search query URLs using multiple search engines if needed
 
     Parameters
     ----------
@@ -198,6 +210,13 @@ async def search_with_fallback(queries, search_engines=_DEFAULT_SE,
     task_name : str, optional
         Optional task name to use in :func:`asyncio.create_task`.
         By default, ``None``.
+    use_fallback_per_query : bool, default=True
+        Option to use the fallback list of search engines on a per-query
+        basis. This means if a single query fails with one search
+        engine, the fallback search engines will be attempted for that
+        query. If this input is ``False``, the fallback search engines
+        are only used if *all* search queries fail for a single search
+        engine. By default, ``True``.
     **kwargs
         Keyword-argument pairs to initialize search engines. This input
         can include and any/all of the following keywords:
@@ -241,13 +260,20 @@ async def search_with_fallback(queries, search_engines=_DEFAULT_SE,
         logger.error(msg)
         raise ELMInputError(msg)
 
-    for se_name in search_engines:
-        logger.debug("Searching web using %r", se_name)
-        urls = await _single_se_search(se_name, queries, num_urls,
-                                       ignore_url_parts, browser_semaphore,
-                                       task_name, kwargs)
+    if use_fallback_per_query:
+        urls = await _multi_se_search(search_engines, queries, num_urls,
+                                      ignore_url_parts, browser_semaphore,
+                                      task_name, kwargs)
         if urls:
             return urls
+    else:
+        for se_name in search_engines:
+            logger.debug("Searching web using %r", se_name)
+            urls = await _single_se_search(se_name, queries, num_urls,
+                                           ignore_url_parts, browser_semaphore,
+                                           task_name, kwargs)
+            if urls:
+                return urls
 
     logger.warning("No web results found using %d search engines: %r",
                    len(search_engines), search_engines)
@@ -293,17 +319,44 @@ async def load_docs(urls, browser_semaphore=None, **kwargs):
 async def _single_se_search(se_name, queries, num_urls, ignore_url_parts,
                             browser_sem, task_name, kwargs):
     """Search for links using a single search engine"""
-    if se_name not in SEARCH_ENGINE_OPTIONS:
-        msg = (f"'se_name' must be one of: {list(SEARCH_ENGINE_OPTIONS)}\n"
-               f"Got {se_name=}")
-        logger.error(msg)
-        raise ELMKeyError(msg)
-
+    _validate_se_name(se_name)
     links = await _run_search(se_name, queries, browser_sem, task_name,
                               kwargs)
     return _down_select_urls(links, num_urls=num_urls,
                              ignore_url_parts=ignore_url_parts)
 
 
+async def _multi_se_search(search_engines, queries, num_urls,
+                           ignore_url_parts, browser_sem, task_name, kwargs):
+    """Search for links using one or more search engines as fallback"""
+    outputs = {q: None for q in queries}
+    remaining_queries = list(queries)
+    for se_name in search_engines:
+        _validate_se_name(se_name)
+
+        logger.debug("Searching web using %r", se_name)
+        links = await _run_search(se_name, remaining_queries, browser_sem,
+                                  task_name, kwargs)
+        logger.trace("Links: %r", links)
+
+        failed_queries = []
+        for q, se_result in zip(remaining_queries, links):
+            if not se_result or not se_result[0]:
+                failed_queries.append(q)
+                continue
+            outputs[q] = se_result
+
+        remaining_queries = failed_queries
+        logger.trace("Remaining queries to search: %r", remaining_queries)
+
+        if not remaining_queries:
+            break
+
+    links = [link or [[]] for link in outputs.values()]
+
+    return _down_select_urls(links, num_urls=num_urls,
+                             ignore_url_parts=ignore_url_parts)
+
+
 async def _run_search(se_name, queries, browser_sem, task_name, kwargs):
     """Run a search for multiple queries on a single search engine"""
     searchers = [asyncio.create_task(_single_query_search(se_name, query,
@@ -383,3 +436,12 @@ def _as_set(user_input):
     if isinstance(user_input, str):
         user_input = {user_input}
     return set(user_input or [])
+
+
+def _validate_se_name(se_name):
+    """Validate user search engine name input"""
+    if se_name not in SEARCH_ENGINE_OPTIONS:
+        msg = (f"'se_name' must be one of: {list(SEARCH_ENGINE_OPTIONS)}\n"
+               f"Got {se_name=}")
+        logger.error(msg)
+        raise ELMKeyError(msg)
diff --git a/elm/web/search/tavily.py b/elm/web/search/tavily.py
index 0fef8cdd..f1a4b2cb 100644
--- a/elm/web/search/tavily.py
+++ b/elm/web/search/tavily.py
@@ -1,8 +1,12 @@
 # -*- coding: utf-8 -*-
 """ELM Web Scraping - Tavily API search"""
+import json
 import logging
 
+import requests
 from tavily import TavilyClient
+from tavily.errors import (UsageLimitExceededError, InvalidAPIKeyError,
+                           BadRequestError, ForbiddenError)
 
 from elm.web.search.base import APISearchEngineLinkSearch
 
@@ -10,6 +14,108 @@
 logger = logging.getLogger(__name__)
 
 
+class _PatchedTavilyClient(TavilyClient):
+    """Patch `TavilyClient` to accept verify keyword"""
+
+    def __init__(self, api_key=None, proxies=None, verify=False):
+        """
+
+        Parameters
+        ----------
+        api_key : str, optional
+            API key for search engine. If ``None``, will look up the API
+            key using the ``"TAVILY_API_KEY"`` environment variable.
+            By default, ``None``.
+        verify : bool, default=False
+            Option to use SSL verification when making request to API
+            endpoint. By default, ``False``.
+ """ + super().__init__(api_key=api_key, proxies=proxies) + self.verify = verify + + def _search(self, query, search_depth="basic", topic="general", + time_range=None, days=7, max_results=5, include_domains=None, + exclude_domains=None, include_answer=False, + include_raw_content=False, include_images=False, timeout=60, + **kwargs): + """Internal search method to send the request to the API""" + + data = {"query": query, "search_depth": search_depth, "topic": topic, + "time_range": time_range, "days": days, + "include_answer": include_answer, + "include_raw_content": include_raw_content, + "max_results": max_results, "include_domains": include_domains, + "exclude_domains": exclude_domains, + "include_images": include_images} + + if kwargs: + data.update(kwargs) + + timeout = min(timeout, 120) + + response = requests.post(self.base_url + "/search", + data=json.dumps(data), headers=self.headers, + timeout=timeout, proxies=self.proxies, + verify=self.verify) + + if response.status_code == 200: + return response.json() + else: + detail = "" + try: + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass + + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403,432,433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) + else: + raise response.raise_for_status() + + def _extract(self, urls, include_images=False, extract_depth="basic", + timeout=60, **kwargs): + """ + Internal extract method to send the request to the API. + """ + data = {"urls": urls, "include_images": include_images, + "extract_depth": extract_depth} + if kwargs: + data.update(kwargs) + + timeout = min(timeout, 120) + + response = requests.post(self.base_url + "/extract", + data=json.dumps(data), headers=self.headers, + timeout=timeout, proxies=self.proxies, + verify=self.verify) + + if response.status_code == 200: + return response.json() + else: + detail = "" + try: + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass + + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403,432,433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) + else: + raise response.raise_for_status() + + class APITavilySearch(APISearchEngineLinkSearch): """Search the web for links using the Tavily API""" @@ -18,10 +124,26 @@ class APITavilySearch(APISearchEngineLinkSearch): API_KEY_VAR = "TAVILY_API_KEY" """Environment variable that should contain the Tavily API key""" + def __init__(self, api_key=None, verify=False): + """ + + Parameters + ---------- + api_key : str, optional + API key for serper search API. If ``None``, will look up the + API key using the ``"TAVILY_API_KEY"`` environment variable. + By default, ``None``. + verify : bool, default=False + Option to use SSL verification when making request to API + endpoint. By default, ``False``. 
+ """ + super().__init__(api_key=api_key) + self.verify = verify + async def _search(self, query, num_results=10): """Search web for links related to a query""" - client = TavilyClient(api_key=self.api_key) + client = _PatchedTavilyClient(api_key=self.api_key, verify=self.verify) response = client.search(query=query, max_results=num_results) results = response.get("results", []) return list(filter(None, (info.get('url', "").replace("+", "%20") diff --git a/elm/web/website_crawl.py b/elm/web/website_crawl.py index a167c56b..01c6962d 100644 --- a/elm/web/website_crawl.py +++ b/elm/web/website_crawl.py @@ -586,4 +586,6 @@ async def _doc_from_result(self, result): def _compute_avg_score(results): """Compute the average score of the crawled results""" + if len(results) <= 0: + return 0 return sum(r.metadata.get('score', 0) for r in results) / len(results)