Merged
2 changes: 1 addition & 1 deletion elm/version.py
@@ -2,4 +2,4 @@
ELM version number
"""

__version__ = "0.0.23"
__version__ = "0.0.24"
20 changes: 18 additions & 2 deletions elm/web/search/google.py
@@ -9,7 +9,7 @@
from contextlib import asynccontextmanager

from camoufox.async_api import AsyncCamoufox
from apiclient.discovery import build
from googleapiclient.discovery import build
from playwright.async_api import TimeoutError as PlaywrightTimeoutError

from elm.web.search.base import (PlaywrightSearchEngineLinkSearch,
@@ -244,6 +244,22 @@ class APISerperSearch(APISearchEngineLinkSearch):
API_KEY_VAR = "SERPER_API_KEY"
"""Environment variable that should contain the Google Serper API key"""

def __init__(self, api_key=None, verify=False):
"""

Parameters
----------
api_key : str, optional
API key for the Serper search API. If ``None``, will look up the
API key using the ``"SERPER_API_KEY"`` environment variable.
By default, ``None``.
verify : bool, default=False
Option to use SSL verification when making requests to the API
endpoint. By default, ``False``.
"""
super().__init__(api_key=api_key)
self.verify = verify

async def _search(self, query, num_results=10):
"""Search web for links related to a query"""

@@ -252,7 +268,7 @@ async def _search(self, query, num_results=10):
'Content-Type': 'application/json'}

response = requests.request("POST", self._URL, headers=headers,
data=payload)
data=payload, verify=self.verify)
results = json.loads(response.text).get('organic', {})
return list(filter(None, (result.get("link", "").replace("+", "%20")
for result in results)))
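As a quick illustration of the new `verify` option on `APISerperSearch`, here is a minimal sketch (not part of this diff). It assumes `SERPER_API_KEY` is exported and calls the internal `_search` coroutine shown above directly, purely for demonstration:

```python
# Sketch only: exercise the new verify flag on APISerperSearch.
# Assumes SERPER_API_KEY is set; _search is the internal coroutine from
# this diff, called directly here just to show the round trip.
import asyncio

from elm.web.search.google import APISerperSearch


async def main():
    # verify=True re-enables SSL verification for the Serper request;
    # the default introduced in this PR is False.
    search = APISerperSearch(verify=True)
    links = await search._search("solar energy permitting", num_results=5)
    print(links)


if __name__ == "__main__":
    asyncio.run(main())
```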
90 changes: 76 additions & 14 deletions elm/web/search/run.py
@@ -56,6 +56,7 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
num_urls=None, ignore_url_parts=None,
search_semaphore=None,
browser_semaphore=None, task_name=None,
use_fallback_per_query=True,
on_search_complete_hook=None,
**kwargs):
"""Retrieve top ``N`` search results as document instances
@@ -75,7 +76,9 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
is used and so on. If all web searches fail, an empty list is
returned. See :obj:`~elm.web.search.run.SEARCH_ENGINE_OPTIONS`
for supported search engine options.
By default, ``("PlaywrightGoogleLinkSearch", )``.
By default, ``("PlaywrightGoogleLinkSearch",
"PlaywrightDuckDuckGoLinkSearch",
"DuxDistributedGlobalSearch")``.
num_urls : int, optional
Number of unique top Google search results to return as docs. The
Google search results from all queries are interleaved and the
@@ -101,6 +104,13 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
task_name : str, optional
Optional task name to use in :func:`asyncio.create_task`.
By default, ``None``.
use_fallback_per_query : bool, default=True
Option to use the fallback list of search engines on a per-query
basis. This means if a single query fails with one search
engine, the fallback search engines will be attempted for that
query. If this input is ``False``, the fallback search engines
are only used if *all* search queries fail for a single search
engine. By default, ``True``.
on_search_complete_hook : callable, optional
If provided, this async callable will be called after the search
engine links have been retrieved. A single argument will be
@@ -145,11 +155,13 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
# backward-compatibility
search_semaphore = browser_semaphore

fpq = use_fallback_per_query
urls = await search_with_fallback(queries, search_engines=search_engines,
num_urls=num_urls,
ignore_url_parts=ignore_url_parts,
browser_semaphore=search_semaphore,
task_name=task_name, **kwargs)
task_name=task_name,
use_fallback_per_query=fpq, **kwargs)
if on_search_complete_hook is not None:
await on_search_complete_hook(urls)
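For reference, a minimal sketch (not part of this diff) of calling `web_search_links_as_docs` with the new `use_fallback_per_query` flag; the query text, `num_urls`, and engine tuple below are illustrative only:

```python
# Sketch only: the engine names come from the docstring above; everything
# else is made up for illustration.
import asyncio

from elm.web.search.run import web_search_links_as_docs


async def main():
    docs = await web_search_links_as_docs(
        ["NREL ELM web scraping"],
        search_engines=("PlaywrightGoogleLinkSearch",
                        "PlaywrightDuckDuckGoLinkSearch"),
        num_urls=5,
        # False restores the old behavior: move on to the next engine only
        # when *every* query fails on the current one.
        use_fallback_per_query=False,
    )
    print(len(docs))


if __name__ == "__main__":
    asyncio.run(main())
```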

@@ -161,7 +173,7 @@ async def web_search_links_as_docs(queries, search_engines=_DEFAULT_SE,
async def search_with_fallback(queries, search_engines=_DEFAULT_SE,
num_urls=None, ignore_url_parts=None,
browser_semaphore=None, task_name=None,
**kwargs):
use_fallback_per_query=True, **kwargs):
"""Retrieve search query URLs using multiple search engines if needed

Parameters
@@ -198,6 +210,13 @@ async def search_with_fallback(queries, search_engines=_DEFAULT_SE,
task_name : str, optional
Optional task name to use in :func:`asyncio.create_task`.
By default, ``None``.
use_fallback_per_query : bool, default=True
Option to use the fallback list of search engines on a per-query
basis. This means if a single query fails with one search
engine, the fallback search engines will be attempted for that
query. If this input is ``False``, the fallback search engines
are only used if *all* search queries fail for a single search
engine. By default, ``True``.
**kwargs
Keyword-argument pairs to initialize search engines. This input
can include any/all of the following keywords:
@@ -241,13 +260,20 @@ async def search_with_fallback(queries, search_engines=_DEFAULT_SE,
logger.error(msg)
raise ELMInputError(msg)

for se_name in search_engines:
logger.debug("Searching web using %r", se_name)
urls = await _single_se_search(se_name, queries, num_urls,
ignore_url_parts, browser_semaphore,
task_name, kwargs)
if use_fallback_per_query:
urls = await _multi_se_search(search_engines, queries, num_urls,
ignore_url_parts, browser_semaphore,
task_name, kwargs)
if urls:
return urls
else:
for se_name in search_engines:
logger.debug("Searching web using %r", se_name)
urls = await _single_se_search(se_name, queries, num_urls,
ignore_url_parts, browser_semaphore,
task_name, kwargs)
if urls:
return urls

logger.warning("No web results found using %d search engines: %r",
len(search_engines), search_engines)
@@ -293,17 +319,44 @@ async def load_docs(urls, browser_semaphore=None, **kwargs):
async def _single_se_search(se_name, queries, num_urls, ignore_url_parts,
browser_sem, task_name, kwargs):
"""Search for links using a single search engine"""
if se_name not in SEARCH_ENGINE_OPTIONS:
msg = (f"'se_name' must be one of: {list(SEARCH_ENGINE_OPTIONS)}\n"
f"Got {se_name=}")
logger.error(msg)
raise ELMKeyError(msg)

_validate_se_name(se_name)
links = await _run_search(se_name, queries, browser_sem, task_name, kwargs)
return _down_select_urls(links, num_urls=num_urls,
ignore_url_parts=ignore_url_parts)


async def _multi_se_search(search_engines, queries, num_urls,
ignore_url_parts, browser_sem, task_name, kwargs):
"""Search for links using one or more search engines as fallback"""
outputs = {q: None for q in queries}
remaining_queries = list(queries)
for se_name in search_engines:
_validate_se_name(se_name)

logger.debug("Searching web using %r", se_name)
links = await _run_search(se_name, remaining_queries, browser_sem,
task_name, kwargs)
logger.trace("Links: %r", links)

failed_queries = []
for q, se_result in zip(remaining_queries, links):
if not se_result or not se_result[0]:
failed_queries.append(q)
continue
outputs[q] = se_result

remaining_queries = failed_queries
logger.trace("Remaining queries to search: %r", remaining_queries)

if not remaining_queries:
break

links = [link or [[]] for link in outputs.values()]

return _down_select_urls(links, num_urls=num_urls,
ignore_url_parts=ignore_url_parts)


async def _run_search(se_name, queries, browser_sem, task_name, kwargs):
"""Run a search for multiple queries on a single search engine"""
searchers = [asyncio.create_task(_single_query_search(se_name, query,
Expand Down Expand Up @@ -383,3 +436,12 @@ def _as_set(user_input):
if isinstance(user_input, str):
user_input = {user_input}
return set(user_input or [])


def _validate_se_name(se_name):
"""Validate user search engine name input"""
if se_name not in SEARCH_ENGINE_OPTIONS:
msg = (f"'se_name' must be one of: {list(SEARCH_ENGINE_OPTIONS)}\n"
f"Got {se_name=}")
logger.error(msg)
raise ELMKeyError(msg)
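The per-query bookkeeping in `_multi_se_search` boils down to: remember which queries already have links, and only re-run the unanswered ones on the next engine. A simplified, standalone sketch of that pattern (not the actual helper — it flattens the nested result lists and uses toy data):

```python
def per_query_fallback(queries, engines, search_one):
    """Return {query: links}; retry only failed queries on later engines."""
    outputs = {q: None for q in queries}
    remaining = list(queries)
    for engine in engines:
        failed = []
        for query in remaining:
            links = search_one(engine, query)
            if links:
                outputs[query] = links  # keep the first engine that answers
            else:
                failed.append(query)    # retry this query on the next engine
        remaining = failed
        if not remaining:
            break
    return outputs


# Toy run: engine "A" only answers q1, engine "B" answers q2.
toy = {("A", "q1"): ["https://example.com/1"],
       ("B", "q2"): ["https://example.com/2"]}
print(per_query_fallback(["q1", "q2"], ["A", "B"],
                         lambda eng, q: toy.get((eng, q), [])))
# -> {'q1': ['https://example.com/1'], 'q2': ['https://example.com/2']}
```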
124 changes: 123 additions & 1 deletion elm/web/search/tavily.py
@@ -1,15 +1,121 @@
# -*- coding: utf-8 -*-
"""ELM Web Scraping - Tavily API search"""
import json
import logging
import requests

from tavily import TavilyClient
from tavily.errors import (UsageLimitExceededError, InvalidAPIKeyError,
BadRequestError, ForbiddenError)

from elm.web.search.base import APISearchEngineLinkSearch


logger = logging.getLogger(__name__)


class _PatchedTavilyClient(TavilyClient):
"""Patch `TavilyClient` to accept verify keyword"""

def __init__(self, api_key=None, proxies=None, verify=False):
"""

Parameters
----------
api_key : str, optional
API key for search engine. If ``None``, will look up the API
key using the ``"TAVILY_API_KEY"`` environment variable.
By default, ``None``.
verify : bool, default=False
Option to use SSL verification when making requests to the API
endpoint. By default, ``False``.
"""
super().__init__(api_key=api_key, proxies=proxies)
self.verify = verify

def _search(self, query, search_depth="basic", topic="general",
time_range=None, days=7, max_results=5, include_domains=None,
exclude_domains=None, include_answer=False,
include_raw_content=False, include_images=False, timeout=60,
**kwargs):
"""Internal search method to send the request to the API"""

data = {"query": query, "search_depth": search_depth, "topic": topic,
"time_range": time_range, "days": days,
"include_answer": include_answer,
"include_raw_content": include_raw_content,
"max_results": max_results, "include_domains": include_domains,
"exclude_domains": exclude_domains,
"include_images": include_images}

if kwargs:
data.update(kwargs)

timeout = min(timeout, 120)

response = requests.post(self.base_url + "/search",
data=json.dumps(data), headers=self.headers,
timeout=timeout, proxies=self.proxies,
verify=self.verify)

if response.status_code == 200:
return response.json()
else:
detail = ""
try:
detail = response.json().get("detail", {}).get("error", None)
except Exception:
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
elif response.status_code == 400:
raise BadRequestError(detail)
else:
raise response.raise_for_status()

def _extract(self, urls, include_images=False, extract_depth="basic",
timeout=60, **kwargs):
"""
Internal extract method to send the request to the API.
"""
data = {"urls": urls, "include_images": include_images,
"extract_depth": extract_depth}
if kwargs:
data.update(kwargs)

timeout = min(timeout, 120)

response = requests.post(self.base_url + "/extract",
data=json.dumps(data), headers=self.headers,
timeout=timeout, proxies=self.proxies,
verify=self.verify)

if response.status_code == 200:
return response.json()
else:
detail = ""
try:
detail = response.json().get("detail", {}).get("error", None)
except Exception:
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
elif response.status_code == 400:
raise BadRequestError(detail)
else:
raise response.raise_for_status()


class APITavilySearch(APISearchEngineLinkSearch):
"""Search the web for links using the Tavily API"""

@@ -18,10 +124,26 @@ class APITavilySearch(APISearchEngineLinkSearch):
API_KEY_VAR = "TAVILY_API_KEY"
"""Environment variable that should contain the Tavily API key"""

def __init__(self, api_key=None, verify=False):
"""

Parameters
----------
api_key : str, optional
API key for the Tavily search API. If ``None``, will look up the
API key using the ``"TAVILY_API_KEY"`` environment variable.
By default, ``None``.
verify : bool, default=False
Option to use SSL verification when making requests to the API
endpoint. By default, ``False``.
"""
super().__init__(api_key=api_key)
self.verify = verify

async def _search(self, query, num_results=10):
"""Search web for links related to a query"""

client = TavilyClient(api_key=self.api_key)
client = _PatchedTavilyClient(api_key=self.api_key, verify=self.verify)
response = client.search(query=query, max_results=num_results)
results = response.get("results", [])
return list(filter(None, (info.get('url', "").replace("+", "%20")
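A matching sketch (not part of this diff) for the Tavily side: the `verify` flag now flows from `APITavilySearch` through `_PatchedTavilyClient` into `requests.post`. It assumes `TAVILY_API_KEY` is exported and calls the internal `_search` coroutine directly for demonstration:

```python
# Sketch only: exercise the new verify flag on APITavilySearch.
import asyncio

from elm.web.search.tavily import APITavilySearch


async def main():
    search = APITavilySearch(verify=True)  # default in this PR is False
    urls = await search._search("transmission line siting", num_results=5)
    print(urls)


if __name__ == "__main__":
    asyncio.run(main())
```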
2 changes: 2 additions & 0 deletions elm/web/website_crawl.py
@@ -586,4 +586,6 @@ async def _doc_from_result(self, result):

def _compute_avg_score(results):
"""Compute the average score of the crawled results"""
if len(results) <= 0:
return 0
return sum(r.metadata.get('score', 0) for r in results) / len(results)
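The two added lines guard the average-score helper against an empty crawl. A tiny sketch of the resulting behavior (the `FakeResult` class is a stand-in for a crawl result carrying a `metadata` dict):

```python
class FakeResult:
    """Stand-in for a crawl result carrying a metadata dict."""

    def __init__(self, score):
        self.metadata = {"score": score}


def _compute_avg_score(results):
    """Compute the average score of the crawled results"""
    if len(results) <= 0:
        return 0  # empty crawl: return 0 instead of raising ZeroDivisionError
    return sum(r.metadata.get("score", 0) for r in results) / len(results)


print(_compute_avg_score([]))                                  # 0
print(_compute_avg_score([FakeResult(0.2), FakeResult(0.8)]))  # 0.5
```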