Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(serper-dev): restore search localization parameters #197

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crewai_tools/tools/serper_dev_tool/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ from crewai_tools import SerperDevTool
tool = SerperDevTool(
n_results=10, # Optional: Number of results to return (default: 10)
save_file=False, # Optional: Save results to file (default: False)
search_type="search" # Optional: Type of search - "search" or "news" (default: "search")
search_type="search", # Optional: Type of search - "search" or "news" (default: "search")
country="us", # Optional: Country for search (default: "")
location="New York", # Optional: Location for search (default: "")
locale="en-US" # Optional: Locale for search (default: "")
)

# Execute a search
Expand Down
16 changes: 14 additions & 2 deletions crewai_tools/tools/serper_dev_tool/serper_dev_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import os
from typing import Any, Type
from typing import Any, Type, Optional

import requests
from crewai.tools import BaseTool
Expand Down Expand Up @@ -45,6 +45,9 @@ class SerperDevTool(BaseTool):
n_results: int = 10
save_file: bool = False
search_type: str = "search"
country: Optional[str] = ""
location: Optional[str] = ""
locale: Optional[str] = ""

def _get_search_url(self, search_type: str) -> str:
"""Get the appropriate endpoint URL based on search type."""
Expand Down Expand Up @@ -146,11 +149,20 @@ def _process_news_results(self, news_results: list) -> list:
def _make_api_request(self, search_query: str, search_type: str) -> dict:
"""Make API request to Serper."""
search_url = self._get_search_url(search_type)
payload = json.dumps({"q": search_query, "num": self.n_results})
payload = {"q": search_query, "num": self.n_results}

if self.country != "":
payload["gl"] = self.country
if self.location != "":
payload["location"] = self.location
if self.locale != "":
payload["hl"] = self.locale

headers = {
"X-API-KEY": os.environ["SERPER_API_KEY"],
"content-type": "application/json",
}
payload = json.dumps(payload)

response = None
try:
Expand Down
151 changes: 151 additions & 0 deletions tests/tools/serper_dev_tool_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
import os


@pytest.fixture(autouse=True)
def mock_serper_api_key():
with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
yield


@pytest.fixture
def serper_tool():
return SerperDevTool(n_results=2)


def test_serper_tool_initialization():
tool = SerperDevTool()
assert tool.n_results == 10
assert tool.save_file is False
assert tool.search_type == "search"
assert tool.country == ""
assert tool.location == ""
assert tool.locale == ""


def test_serper_tool_custom_initialization():
tool = SerperDevTool(
n_results=5,
save_file=True,
search_type="news",
country="US",
location="New York",
locale="en"
)
assert tool.n_results == 5
assert tool.save_file is True
assert tool.search_type == "news"
assert tool.country == "US"
assert tool.location == "New York"
assert tool.locale == "en"


@patch("requests.post")
def test_serper_tool_search(mock_post):
tool = SerperDevTool(n_results=2)
mock_response = {
"searchParameters": {
"q": "test query",
"type": "search"
},
"organic": [
{
"title": "Test Title 1",
"link": "http://test1.com",
"snippet": "Test Description 1",
"position": 1
},
{
"title": "Test Title 2",
"link": "http://test2.com",
"snippet": "Test Description 2",
"position": 2
}
],
"peopleAlsoAsk": [
{
"question": "Test Question",
"snippet": "Test Answer",
"title": "Test Source",
"link": "http://test.com"
}
]
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200

result = tool.run(search_query="test query")

assert "searchParameters" in result
assert result["searchParameters"]["q"] == "test query"
assert len(result["organic"]) == 2
assert result["organic"][0]["title"] == "Test Title 1"


@patch("requests.post")
def test_serper_tool_news_search(mock_post):
tool = SerperDevTool(n_results=2, search_type="news")
mock_response = {
"searchParameters": {
"q": "test news",
"type": "news"
},
"news": [
{
"title": "News Title 1",
"link": "http://news1.com",
"snippet": "News Description 1",
"date": "2024-01-01",
"source": "News Source 1",
"imageUrl": "http://image1.com"
}
]
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200

result = tool.run(search_query="test news")

assert "news" in result
assert len(result["news"]) == 1
assert result["news"][0]["title"] == "News Title 1"


@patch("requests.post")
def test_serper_tool_with_location_params(mock_post):
tool = SerperDevTool(
n_results=2,
country="US",
location="New York",
locale="en"
)

tool.run(search_query="test")

called_payload = mock_post.call_args.kwargs["json"]
assert called_payload["gl"] == "US"
assert called_payload["location"] == "New York"
assert called_payload["hl"] == "en"


def test_invalid_search_type():
tool = SerperDevTool()
with pytest.raises(ValueError) as exc_info:
tool.run(search_query="test", search_type="invalid")
assert "Invalid search type" in str(exc_info.value)


@patch("requests.post")
def test_api_error_handling(mock_post):
tool = SerperDevTool()
mock_post.side_effect = Exception("API Error")

with pytest.raises(Exception) as exc_info:
tool.run(search_query="test")
assert "API Error" in str(exc_info.value)


if __name__ == "__main__":
pytest.main([__file__])