diff --git a/README.md b/README.md index 35a2edf5a6..ba640d4b06 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Toolkit is a deployable all-in-one RAG application that enables users to quickly - [How to setup Gmail](/docs/custom_tool_guides/gmail.md) - [How to setup Slack Tool](/docs/custom_tool_guides/slack.md) - [How to setup Github Tool](/docs/custom_tool_guides/github.md) + - [How to setup Sharepoint](/docs/custom_tool_guides/sharepoint.md) - [How to setup Google Text-to-Speech](/docs/text_to_speech.md) - [How to add authentication](/docs/auth_guide.md) - [How to deploy toolkit services](/docs/service_deployments.md) @@ -28,15 +29,20 @@ Toolkit is a deployable all-in-one RAG application that enables users to quickly ![](/docs/assets/toolkit.gif) ## Try Now: -There are two main ways for quickly running Toolkit: local and cloud. See the specific instructions given below. + +There are two main ways for quickly running Toolkit: local and cloud. See the specific instructions given below. + ### Local -*You will need to have [Docker](https://www.docker.com/products/docker-desktop/), [Docker-compose >= 2.22](https://docs.docker.com/compose/install/), and [Poetry](https://python-poetry.org/docs/#installation) installed. [Go here for a more detailed setup.](/docs/setup.md)* + +_You will need to have [Docker](https://www.docker.com/products/docker-desktop/), [Docker-compose >= 2.22](https://docs.docker.com/compose/install/), and [Poetry](https://python-poetry.org/docs/#installation) installed. [Go here for a more detailed setup.](/docs/setup.md)_ Note: to include community tools when building locally, set the `INSTALL_COMMUNITY_DEPS` build arg in the `docker-compose.yml` to `true`. Both options will serve the frontend at http://localhost:4000. #### Using `make` + Use the provided Makefile to simplify and automate your development workflow with Cohere Toolkit, including Docker Compose management, testing, linting, and environment setup. + ```bash git clone https://github.com/cohere-ai/cohere-toolkit.git cd cohere-toolkit @@ -44,14 +50,18 @@ make first-run ``` #### Docker Compose only + Use Docker Compose directly if you want to quickly spin up and manage your container environment without the additional automation provided by the Makefile. + ```bash git clone https://github.com/cohere-ai/cohere-toolkit.git cd cohere-toolkit docker compose up docker compose run --build backend alembic -c src/backend/alembic.ini upgrade head ``` + ### Cloud + #### GitHub Codespaces To run this project using GitHub Codespaces, please refer to our [Codespaces Setup Guide](/docs/github_codespaces.md). @@ -63,7 +73,7 @@ To run this project using GitHub Codespaces, please refer to our [Codespaces Set - **Interfaces** - any client-side UI, currently contains two web apps, one agentic and one basic, and a Slack bot implementation. - Defaults to Cohere's Web UI at `src/interfaces/assistants_web` - A web app built in Next.js. Includes a simple SQL database out of the box to store conversation history in the app. - You can change the Web UI using the docker compose file. -- **Backend API** - in `src/backend` this follows a similar structure to the [Cohere Chat API](https://docs.cohere.com/reference/chat) but also include customizable elements: +- **Backend API** - in `src/backend` this follows a similar structure to the [Cohere Chat API](https://docs.cohere.com/reference/chat) but also include customizable elements: - **Model** - you can customize with which provider you access Cohere's Command models. By default included in the toolkit is Cohere's Platform, Sagemaker, Azure, Bedrock, HuggingFace, local models. [More details here.](/docs/command_model_providers.md) - **Retrieval**- you can customize tools and data sources that the application is run with. - **Service Deployment Guides** - we also include guides for how to deploy the toolkit services in production including with AWS, GCP and Azure. [More details here.](/docs/service_deployments.md) diff --git a/docs/custom_tool_guides/sharepoint.md b/docs/custom_tool_guides/sharepoint.md new file mode 100644 index 0000000000..0ac44bed13 --- /dev/null +++ b/docs/custom_tool_guides/sharepoint.md @@ -0,0 +1,39 @@ +# Sharepoint Tool Setup + +To setup the Sharepoint tool you need to configure API access via the following steps + +## 1. Configure Tenant ID and Client ID + +Your Microsoft Tenant ID and Client ID can be found my navigating to the [Micorsoft Entra Admin Center](https://entra.microsoft.com/) and then going to the `Overview` Page under the `Identity Section`. There the Tenant ID is listed as Tenant ID, and the Client ID is listed as the Application ID. + +Copy your Tenant ID into the `configuration.yaml` file in the config directory of the backend, and your Client ID into the `secrets.yaml` file in the config directory of the backend. + +## 2. Register New Application + +Navigate to the `App registration` page under `Applications` on the same [Micorsoft Entra Admin Center](https://entra.microsoft.com/) website. + +Click `New registration` to register a new application. Enter a name and select the proper account type. Single tenant is the norm unless you know of otherwise. + +Under redirect URI select Web as the path should be `/v1/tool/auth`. For example: + +```bash + https:///v1/tool/auth +``` + +Click `Register` to Complete the Application Registration + +## 3. Configure Permissions + +Under the newly registered application navigate to the `API permissions` page. There you need to Click `Add a permission`, select `Microsoft Graph`, then `delegated permissions`. Next search `files.read.all` and check the box, then search `sites.read.all` and check the box. Then Click `Add permissions`. + +## 3. Configure Client Secret + +Under the newly registered application navigate to the `Certificates & secrets` page. Click `New client secret`, enter a description and an expiry then click `Add`. Your new Client Secret is only available to copy under the `value` column of the table right now. Copy it into the `secrets.yaml` file in the config directory of the backend. + +## 5. Run the Backend and Frontend + +run next command to start the backend and frontend: + +```bash +make dev +``` diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index 767edb5012..34d04daba5 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -42,7 +42,8 @@ tools: - public_repo default_repos: - cohere-ai/cohere-toolkit - - EugeneLightsOn/cohere-toolkit + sharepoint: + tenant_id: # To disable the use of the tools preamble, set it to false use_tools_preamble: true feature_flags: diff --git a/src/backend/config/secrets.template.yaml b/src/backend/config/secrets.template.yaml index c4f22b23ed..f0b41e75cf 100644 --- a/src/backend/config/secrets.template.yaml +++ b/src/backend/config/secrets.template.yaml @@ -40,6 +40,9 @@ tools: github: client_id: client_secret: + sharepoint: + client_id: + client_secret: auth: secret_key: google_oauth: @@ -53,4 +56,4 @@ auth: client_secret: well_known_endpoint: google_cloud: - api_key: \ No newline at end of file + api_key: diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 189bdd441e..d8d73f4ea7 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -271,6 +271,22 @@ class HybridWebSearchSettings(BaseSettings, BaseModel): site_filters: Optional[List[str]] = [] +class SharepointSettings(BaseSettings, BaseModel): + model_config = SETTINGS_CONFIG + tenant_id: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("SHAREPOINT_TENANT_ID", "tenant_id"), + ) + client_id: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("SHAREPOINT_CLIENT_ID", "client_id"), + ) + client_secret: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("SHAREPOINT_CLIENT_SECRET", "client_secret"), + ) + + class ToolSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG @@ -302,6 +318,9 @@ class ToolSettings(BaseSettings, BaseModel): gmail: Optional[GmailSettings] = Field( default=GmailSettings() ) + sharepoint: Optional[SharepointSettings] = Field( + default=SharepointSettings() + ) use_tools_preamble: Optional[bool] = Field( default=False, validation_alias=AliasChoices("USE_TOOLS_PREAMBLE", "use_tools_preamble") diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 742f8339f5..d97e77f981 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -14,6 +14,7 @@ PythonInterpreter, ReadFileTool, SearchFileTool, + SharepointTool, SlackTool, TavilyWebSearch, WebScrapeTool, @@ -40,6 +41,7 @@ class Tool(Enum): Slack = SlackTool Gmail = GmailTool Github = GithubTool + Sharepoint = SharepointTool def get_available_tools() -> dict[str, ToolDefinition]: diff --git a/src/backend/tests/unit/tools/test_lang_chain.py b/src/backend/tests/unit/tools/test_lang_chain.py index 31ede1ce27..bef392a06c 100644 --- a/src/backend/tests/unit/tools/test_lang_chain.py +++ b/src/backend/tests/unit/tools/test_lang_chain.py @@ -72,7 +72,13 @@ async def test_wiki_retriever_no_docs() -> None: ): result = await retriever.call({"query": query}, ctx) - assert result == [ToolError(type=ToolErrorCode.OTHER, success=False, text='No results found.', details='No results found for the given params.').model_dump()] + expected_error = ToolError( + type=ToolErrorCode.OTHER, + success=False, + text='No results found.', + details='No results found for the given params.' + ).model_dump() + assert result == [expected_error] @@ -156,4 +162,10 @@ async def test_vector_db_retriever_no_docs() -> None: mock_db.as_retriever().get_relevant_documents.return_value = mock_docs result = await retriever.call({"query": query}, ctx) - assert result == [ToolError(type=ToolErrorCode.OTHER, success=False, text='No results found.', details='No results found for the given params.').model_dump()] + expected_error = ToolError( + type=ToolErrorCode.OTHER, + success=False, + text='No results found.', + details='No results found for the given params.' + ).model_dump() + assert result == [expected_error] diff --git a/src/backend/tools/__init__.py b/src/backend/tools/__init__.py index 9fed5364da..c9c99b4b8a 100644 --- a/src/backend/tools/__init__.py +++ b/src/backend/tools/__init__.py @@ -1,12 +1,14 @@ from backend.tools.brave_search import BraveWebSearch from backend.tools.calculator import Calculator from backend.tools.files import ReadFileTool, SearchFileTool +from backend.tools.github import GithubAuth, GithubTool from backend.tools.gmail import GmailAuth, GmailTool from backend.tools.google_drive import GoogleDrive, GoogleDriveAuth from backend.tools.google_search import GoogleWebSearch from backend.tools.hybrid_search import HybridWebSearch from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever from backend.tools.python_interpreter import PythonInterpreter +from backend.tools.sharepoint import SharepointAuth, SharepointTool from backend.tools.slack import SlackAuth, SlackTool from backend.tools.tavily_search import TavilyWebSearch from backend.tools.web_scrape import WebScrapeTool @@ -29,4 +31,8 @@ "SlackAuth", "GmailTool", "GmailAuth", + "SharepointTool", + "SharepointAuth", + "GithubTool", + "GithubAuth", ] diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index d52f615526..e935997173 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -1,7 +1,7 @@ import datetime from abc import ABC, abstractmethod from enum import StrEnum -from typing import Any, Dict, List +from typing import Any import requests from fastapi import Request @@ -11,6 +11,7 @@ from backend.crud import tool_auth as tool_auth_crud from backend.database_models.database import DBSessionDep from backend.database_models.tool_auth import ToolAuth +from backend.schemas.context import Context from backend.schemas.tool import ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.utils.tools_checkers import check_tool_parameters @@ -157,20 +158,22 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: ... @classmethod - def get_tool_error(cls, details: str, text: str = "Error calling tool", error_type: ToolErrorCode = ToolErrorCode.OTHER): + def get_tool_error( + cls, details: str, text: str = "Error calling tool", error_type: ToolErrorCode = ToolErrorCode.OTHER, + ) -> list[dict[str, str]]: tool_error = ToolError(text=f"{text} {cls.ID}.", details=details, type=error_type).model_dump() logger.error(event=f"Error calling tool {cls.ID}", error=tool_error) return [tool_error] @classmethod - def get_no_results_error(cls): + def get_no_results_error(cls) -> list[dict[str, str]]: tool_error = ToolError(text="No results found.", details="No results found for the given params.").model_dump() return [tool_error] @abstractmethod async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: ... @classmethod @@ -248,7 +251,7 @@ def try_refresh_token( @abstractmethod def retrieve_auth_token( self, request: Request, session: DBSessionDep, user_id: str - ) -> str: + ) -> str|None: ... def get_token(self, session: DBSessionDep, user_id: str) -> str: diff --git a/src/backend/tools/brave_search/tool.py b/src/backend/tools/brave_search/tool.py index 4e424b5f40..c631e00510 100644 --- a/src/backend/tools/brave_search/tool.py +++ b/src/backend/tools/brave_search/tool.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, List +from typing import Any from backend.config.settings import Settings -from backend.database_models.database import DBSessionDep +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool, ToolArgument from backend.tools.brave_search.client import BraveClient @@ -40,11 +40,11 @@ def get_tool_definition(cls) -> ToolDefinition: "Returns a list of relevant document snippets for a textual query retrieved " "from the internet using Brave Search." ), - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any + ) -> list[dict[str, Any]]: query = parameters.get("query", "") # Get domain filtering from kwargs diff --git a/src/backend/tools/calculator.py b/src/backend/tools/calculator.py index 5de2c1dcda..f89b5fc0bf 100644 --- a/src/backend/tools/calculator.py +++ b/src/backend/tools/calculator.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List +from typing import Any from py_expression_eval import Parser +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -35,11 +36,11 @@ def get_tool_definition(cls) -> ToolDefinition: category=ToolCategory.Function, error_message=cls.generate_error_message(), description="A powerful multi-purpose calculator capable of a wide array of math calculations.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any + ) -> list[dict[str, Any]]: logger = ctx.get_logger() math_parser = Parser() diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 707d4cb1ec..3d94daf34e 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -1,7 +1,8 @@ from enum import StrEnum -from typing import Any, Dict, List +from typing import Any import backend.crud.file as file_crud +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -43,16 +44,19 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.FileLoader, description="Returns the chunked textual contents of an uploaded file.", - ) + ) # type: ignore + @classmethod def get_info(cls) -> ToolDefinition: return ToolDefinition( display_name="Calculator", description="A powerful multi-purpose calculator capable of a wide array of math calculations.", error_message=cls.generate_error_message(), - ) + ) # type: ignore - async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + async def call( + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: file = parameters.get("file") session = kwargs.get("session") @@ -113,16 +117,16 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.FileLoader, description="Searches across one or more attached files based on a textual search query.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: query = parameters.get("search_query") files = parameters.get("files") session = kwargs.get("session") - user_id = kwargs.get("user_id") + user_id = kwargs.get("user_id", "") if not query or not files: return self.get_tool_error( diff --git a/src/backend/tools/github/__init__.py b/src/backend/tools/github/__init__.py index e69de29bb2..7c64770917 100644 --- a/src/backend/tools/github/__init__.py +++ b/src/backend/tools/github/__init__.py @@ -0,0 +1,11 @@ +from backend.tools.github.auth import GithubAuth +from backend.tools.github.constants import ( + GITHUB_TOOL_ID, +) +from backend.tools.github.tool import GithubTool + +__all__ = [ + "GithubAuth", + "GithubTool", + "GITHUB_TOOL_ID", +] diff --git a/src/backend/tools/github/tool.py b/src/backend/tools/github/tool.py index e218638626..fc94a4fca7 100644 --- a/src/backend/tools/github/tool.py +++ b/src/backend/tools/github/tool.py @@ -1,10 +1,11 @@ -from typing import Any, Dict, List, Union +from typing import Any from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory -from backend.tools.base import BaseTool, ToolError +from backend.tools.base import BaseTool from backend.tools.github.auth import GithubAuth from backend.tools.github.constants import GITHUB_TOOL_ID, SEARCH_LIMIT from backend.tools.github.utils import get_github_service @@ -63,7 +64,7 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: ) raise Exception(message) - async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> Union[List[Dict[str, Any]], ToolError]: + async def call(self, parameters: dict, ctx: Context, **kwargs: Any) -> list[dict[str, Any]]: user_id = kwargs.get("user_id", "") query = parameters.get("query", "") diff --git a/src/backend/tools/gmail/tool.py b/src/backend/tools/gmail/tool.py index 24a43c5028..603f0109a9 100644 --- a/src/backend/tools/gmail/tool.py +++ b/src/backend/tools/gmail/tool.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List +from typing import Any from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool @@ -44,7 +45,7 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Returns a list of relevant email snippets from Gmail.", - ) + ) # type: ignore @classmethod def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: @@ -62,7 +63,9 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: ) raise Exception(message) - async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> List[Dict[str, Any]]: + async def call( + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: user_id = kwargs.get("user_id", "") query = parameters.get("query", "") @@ -70,4 +73,8 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> List[Dict[str results = gmail_service.search_all(query=query) message_ids = [message["id"] for message in results["messages"]] if "messages" in results else [] messages = gmail_service.retrieve_messages(message_ids) + + if not messages: + return self.get_no_results_error() + return gmail_service.serialize_results(messages) diff --git a/src/backend/tools/google_drive/tool.py b/src/backend/tools/google_drive/tool.py index d5b0a1f8f5..587752582e 100644 --- a/src/backend/tools/google_drive/tool.py +++ b/src/backend/tools/google_drive/tool.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, List +from typing import Any from google.auth.exceptions import RefreshError from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool @@ -53,7 +54,7 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Returns a list of relevant document snippets from the user's Google drive.", - ) + ) # type: ignore def _handle_tool_specific_errors(self, error: Exception, **kwargs: Any): message = "[Google Drive] Tool Error: {}".format(str(error)) @@ -70,8 +71,10 @@ def _handle_tool_specific_errors(self, error: Exception, **kwargs: Any): ) raise Exception(message) - async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: - user_id = kwargs.get("user_id") + async def call( + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: + user_id = kwargs.get("user_id", "") query = parameters.get("query", "").replace("'", "\\'") # Search Google Drive @@ -92,7 +95,7 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: async def _default_gdrive_list_files( - user_id: str, query: str, agent_tool_metadata: Dict[str, str] + user_id: str, query: str, agent_tool_metadata: dict[str, str] ): from backend.tools.google_drive.constants import ( DOC_FIELDS, diff --git a/src/backend/tools/google_search.py b/src/backend/tools/google_search.py index 14a7e21dd1..e1840084e9 100644 --- a/src/backend/tools/google_search.py +++ b/src/backend/tools/google_search.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, List +from typing import Any from googleapiclient.discovery import build from backend.config.settings import Settings -from backend.database_models.database import DBSessionDep +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool, ToolArgument @@ -38,11 +38,11 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.WebSearch, description="Returns relevant results by performing a Google web search.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: query = parameters.get("query", "") cse = self.client.cse() diff --git a/src/backend/tools/hybrid_search.py b/src/backend/tools/hybrid_search.py index 32843acef1..fa4276e4cd 100644 --- a/src/backend/tools/hybrid_search.py +++ b/src/backend/tools/hybrid_search.py @@ -3,8 +3,8 @@ from typing import Any, Callable, Dict, List from backend.config.settings import Settings -from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool, ToolArgument from backend.tools.brave_search.tool import BraveWebSearch @@ -60,7 +60,7 @@ def get_tool_definition(cls) -> ToolDefinition: "Returns a list of relevant document snippets for a textual query " "retrieved from the internet using a mix of any existing Web Search tools." ) - ) + ) # type: ignore @classmethod def get_available_search_tools(cls): @@ -74,13 +74,13 @@ def get_available_search_tools(cls): return available_search_tools def _gather_search_tasks( - self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any + self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Callable]: tasks = [] # Add search tool calls for search_tool in self.search_tools: - tasks.append(search_tool.call(parameters, ctx, session, **kwargs)) + tasks.append(search_tool.call(parameters, ctx, **kwargs)) # Add web scrape tool calls filtered_sites = kwargs.get(ToolArgument.SITE_FILTER, []) @@ -90,8 +90,8 @@ def _gather_search_tasks( return tasks async def call( - self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any + ) -> list[dict[str, Any]]: # Retrieve query for reranking query = parameters.get("query", "") @@ -100,7 +100,7 @@ async def call( # Handle site filtering -> perform web scraping on sites kwargs[ToolArgument.SITE_FILTER] = self.SITE_FILTER - tasks = self._gather_search_tasks(parameters, ctx, session, **kwargs) + tasks = self._gather_search_tasks(parameters, ctx, **kwargs) # Gather and run searches results = await asyncio.gather(*tasks) diff --git a/src/backend/tools/lang_chain.py b/src/backend/tools/lang_chain.py index 71f12c5d1c..68665b7259 100644 --- a/src/backend/tools/lang_chain.py +++ b/src/backend/tools/lang_chain.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from langchain.text_splitter import CharacterTextSplitter from langchain_cohere import CohereEmbeddings @@ -7,6 +7,7 @@ from langchain_community.vectorstores import Chroma from backend.config.settings import Settings +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -52,11 +53,11 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Retrieves documents from Wikipedia.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: wiki_retriever = WikipediaRetriever() query = parameters.get("query", "") try: @@ -113,11 +114,11 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Retrieves documents from Wikipedia.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: cohere_embeddings = CohereEmbeddings(cohere_api_key=self.COHERE_API_KEY) # Load text files and split into chunks diff --git a/src/backend/tools/python_interpreter.py b/src/backend/tools/python_interpreter.py index e7015703f9..24366b5886 100644 --- a/src/backend/tools/python_interpreter.py +++ b/src/backend/tools/python_interpreter.py @@ -1,10 +1,11 @@ import json -from typing import Any, Dict, Mapping +from typing import Any, Mapping import requests from dotenv import load_dotenv from backend.config.settings import Settings +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -50,9 +51,11 @@ def get_tool_definition(cls) -> ToolDefinition: "in a static sandbox without internet access and without interactive mode, " "so print output or save output to a file." ), - ) + ) # type: ignore - async def call(self, parameters: dict, ctx: Any, **kwargs: Any): + async def call( + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: if not self.INTERPRETER_URL: raise Exception("Python Interpreter tool called while URL not set") @@ -68,7 +71,7 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any): return clean_res - def _clean_response(self, result: Any) -> Dict[str, str]: + def _clean_response(self, result: Any) -> list[dict[str, str]]: if "final_expression" in result: result["final_expression"] = str(result["final_expression"]) diff --git a/src/backend/tools/sharepoint/__init__.py b/src/backend/tools/sharepoint/__init__.py new file mode 100644 index 0000000000..f025ff8e51 --- /dev/null +++ b/src/backend/tools/sharepoint/__init__.py @@ -0,0 +1,11 @@ +from backend.tools.sharepoint.auth import SharepointAuth +from backend.tools.sharepoint.constants import ( + SHAREPOINT_TOOL_ID, +) +from backend.tools.sharepoint.tool import SharepointTool + +__all__ = [ + "SharepointAuth", + "SharepointTool", + "SHAREPOINT_TOOL_ID", +] diff --git a/src/backend/tools/sharepoint/auth.py b/src/backend/tools/sharepoint/auth.py new file mode 100644 index 0000000000..2fc96940f8 --- /dev/null +++ b/src/backend/tools/sharepoint/auth.py @@ -0,0 +1,147 @@ +import datetime +import json +import urllib.parse + +import requests +from fastapi import Request + +from backend.config.settings import Settings +from backend.crud import tool_auth as tool_auth_crud +from backend.database_models.database import DBSessionDep +from backend.database_models.tool_auth import ToolAuth as ToolAuthModel +from backend.schemas.tool_auth import UpdateToolAuth +from backend.services.auth.crypto import encrypt +from backend.services.logger.utils import LoggerFactory +from backend.tools.base import BaseToolAuthentication +from backend.tools.sharepoint.constants import SHAREPOINT_TOOL_ID +from backend.tools.utils.mixins import ToolAuthenticationCacheMixin + +logger = LoggerFactory().get_logger() + + +class SharepointAuth(BaseToolAuthentication, ToolAuthenticationCacheMixin): + TOOL_ID = SHAREPOINT_TOOL_ID + AUTH_ENDPOINT = "https://login.microsoftonline.com" + SCOPES = [ + "offline_access", + "https://graph.microsoft.com/.default", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.REDIRECT_URL = f"{self.BACKEND_HOST}/v1/tool/auth" + + self.SHAREPOINT_TENANT_ID = Settings().get('tools.sharepoint.tenant_id') + self.SHAREPOINT_CLIENT_ID = Settings().get('tools.sharepoint.client_id') + self.SHAREPOINT_CLIENT_SECRET = Settings().get('tools.sharepoint.client_secret') + + if not any([self.SHAREPOINT_TENANT_ID, self.SHAREPOINT_CLIENT_ID, self.SHAREPOINT_CLIENT_SECRET]): + raise ValueError( + "SHAREPOINT_TENANT_ID, SHAREPOINT_CLIENT_ID and SHAREPOINT_CLIENT_SECRET must be set to use Sharepoint Tool Auth." + ) + + def get_auth_url(self, user_id: str) -> str: + key = self.insert_tool_auth_cache(user_id, self.TOOL_ID) + state = {"key": key} + + params = { + "response_type": "code", + "client_id": self.SHAREPOINT_CLIENT_ID, + "scope": " ".join(self.SCOPES or []), + "redirect_uri": self.REDIRECT_URL, + "prompt": "select_account", + "state": json.dumps(state), + } + + return f"{self.AUTH_ENDPOINT}/{self.SHAREPOINT_TENANT_ID}/oauth2/v2.0/authorize?{urllib.parse.urlencode(params)}" + + def retrieve_auth_token( + self, request: Request, session: DBSessionDep, user_id: str + ) -> str|None: + url = f"{self.AUTH_ENDPOINT}/{self.SHAREPOINT_TENANT_ID}/oauth2/v2.0/token" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + } + code = request.query_params.get("code", "") + payload = { + "grant_type": "authorization_code", + "code": code, + "scope": " ".join(self.SCOPES), + "redirect_uri": self.REDIRECT_URL, + "client_id": self.SHAREPOINT_CLIENT_ID, + "client_secret": self.SHAREPOINT_CLIENT_SECRET, + } + + error_message = "Error retrieving access token from Sharepoint Tool" + try: + response = requests.post(url, data=payload, headers=headers) + body = response.json() + except Exception as exc: + logger.error(event=f"[Sharepoint] Auth token error: {exc}") + return error_message + if not response.ok: + error = body.get("error") + error_description = body.get("error_description") + error_message = f"{error_message}: {error}. {error_description}" + logger.error(event=f"[Sharepoint] Auth token error: {error_message}") + return error_message + + tool_auth_crud.create_tool_auth( + session, + ToolAuthModel( + user_id=user_id, + tool_id=self.TOOL_ID, + token_type=body["token_type"], + encrypted_access_token=encrypt(body["access_token"]), + encrypted_refresh_token=encrypt(body["refresh_token"]), + expires_at=datetime.datetime.now() + + datetime.timedelta(seconds=int(body["expires_in"])), + ), + ) + + def try_refresh_token( + self, session: DBSessionDep, user_id: str, tool_auth: ToolAuthModel + ) -> bool: + url = f"{self.AUTH_ENDPOINT}/{self.SHAREPOINT_TENANT_ID}/oauth2/v2.0/token" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + } + payload = { + "grant_type": "refresh_token", + "refresh_token": tool_auth.refresh_token, + "scope": " ".join(self.SCOPES), + "client_id": self.SHAREPOINT_CLIENT_ID, + "client_secret": self.SHAREPOINT_CLIENT_SECRET, + } + + error_message = "Error retrieving refreshing token from Sharepoint Tool" + try: + response = requests.post(url, data=payload, headers=headers) + body = response.json() + except Exception as exc: + logger.error(event=f"[Sharepoint] Auth token error: {exc}") + return False + if not response.ok: + error = body.get("error") + error_description = body.get("error_description") + error_message = f"{error_message}: {error}. {error_description}" + logger.error(event=f"[Sharepoint] Auth token error: {error_message}") + return False + + existing_tool_auth = tool_auth_crud.get_tool_auth( + session, self.TOOL_ID, user_id + ) + tool_auth_crud.update_tool_auth( + session, + existing_tool_auth, + UpdateToolAuth( + user_id=user_id, + tool_id=self.TOOL_ID, + token_type=body["token_type"], + encrypted_access_token=encrypt(body["access_token"]), + encrypted_refresh_token=encrypt(body["refresh_token"]), + expires_at=datetime.datetime.now() + + datetime.timedelta(seconds=body["expires_in"]), + ), + ) + return True diff --git a/src/backend/tools/sharepoint/constants.py b/src/backend/tools/sharepoint/constants.py new file mode 100644 index 0000000000..da7008ecbd --- /dev/null +++ b/src/backend/tools/sharepoint/constants.py @@ -0,0 +1,6 @@ +""" +Constants for Sharepoint Tool +""" + +SHAREPOINT_TOOL_ID = "sharepoint" +SEARCH_LIMIT = 10 diff --git a/src/backend/tools/sharepoint/tool.py b/src/backend/tools/sharepoint/tool.py new file mode 100644 index 0000000000..8f5c175418 --- /dev/null +++ b/src/backend/tools/sharepoint/tool.py @@ -0,0 +1,167 @@ +from typing import Any + +import requests + +from backend.config.settings import Settings +from backend.database_models.database import get_session +from backend.schemas.context import Context +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.services.logger.utils import LoggerFactory +from backend.tools.base import BaseTool, ToolAuthException +from backend.tools.sharepoint.auth import SharepointAuth +from backend.tools.sharepoint.constants import SEARCH_LIMIT, SHAREPOINT_TOOL_ID +from backend.tools.sharepoint.utils import serialize_file_contents, serialize_metadata + +logger = LoggerFactory().get_logger() + + +class SharepointTool(BaseTool): + ID = SHAREPOINT_TOOL_ID + SHAREPOINT_TENANT_ID = Settings().get('tools.sharepoint.tenant_id') + SHAREPOINT_CLIENT_ID = Settings().get('tools.sharepoint.client_id') + SHAREPOINT_CLIENT_SECRET = Settings().get('tools.sharepoint.client_secret') + + BASE_URL = "https://graph.microsoft.com/v1.0" + SEARCH_ENTITY_TYPES = ["driveItem"] + DRIVE_ITEM_DATA_TYPE = "#microsoft.graph.driveItem" + + @classmethod + def is_available(cls) -> bool: + return all([ + cls.SHAREPOINT_TENANT_ID, + cls.SHAREPOINT_CLIENT_ID, + cls.SHAREPOINT_CLIENT_SECRET, + ]) + + + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Sharepoint", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search Sharepoint documents with.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=SharepointTool.is_available(), + auth_implementation=SharepointAuth, + should_return_token=True, + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Returns a list of relevant document snippets from the user's Sharepoint.", + ) # type: ignore + + def _prepare_auth(self, user_id: str) -> None: + sharepoint_auth = SharepointAuth() + session = next(get_session()) + if sharepoint_auth.is_auth_required(session, user_id=user_id): + session.close() + raise ToolAuthException( + "Sharepoint Tool auth Error: Agent creator credentials need to re-authenticate", + SHAREPOINT_TOOL_ID, + ) + + access_token = sharepoint_auth.get_token(session, user_id) + self.headers = { + "Authorization": f"Bearer {access_token}" + } + + def search(self, query: str) -> list[dict]: + request = { + "entityTypes": self.SEARCH_ENTITY_TYPES, + "query": { + "queryString": query, + "size": SEARCH_LIMIT, + }, + } + + error_message = "Error while searching with Sharepoint Tool" + try: + response = requests.post( + f"{self.BASE_URL}/search/query", + headers=self.headers, + json={"requests": [request]}, + ) + body = response.json() + except Exception as exc: + logger.error(event=f"[Sharepoint] Search error: {exc}") + raise Exception(error_message) from exc + if not response.ok: + error = body.get("error", {}) + error_code = error.get("code") + error_description = error.get("message") + error_message = f"{error_message}: {error_code}. {error_description}" + logger.error(event=f"[Sharepoint] Search error: {error_message}") + raise Exception(error_message) + + if not body.get("value"): + return [] + + return body["value"][0].get("hitsContainers", []) + + def get_drive_item_content(self, parent_drive_id: str, resource_id: str) -> bytes|None: + response = requests.get( + f"{self.BASE_URL}/drives/{parent_drive_id}/items/{resource_id}/content", + headers=self.headers, + ) + + # Fail gracefully when retrieving content + if not response.ok: + return None + + return response.content + + def collect_items(self, hits: list[dict]) -> list: + # Gather data + drive_items = [] + for hit in hits: + if hit["resource"]["@odata.type"] == self.DRIVE_ITEM_DATA_TYPE: + parent_drive_id = hit["resource"]["parentReference"]["driveId"] + resource_id = hit["resource"]["id"] + drive_item = self.get_drive_item_content( + parent_drive_id, resource_id + ) + + if drive_item: + drive_items.append((hit, drive_item)) + + return drive_items + + + async def call( + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: + user_id = str(kwargs.get("user_id", "")) + self._prepare_auth(user_id) + query = parameters.get("query", "").replace("'", "\\'") + search_response = self.search(query) + + hits = [] + for hit_container in search_response: + hits.extend(hit_container.get("hits", [])) + + drive_items = self.collect_items(hits) + + # Serialize results + results = [] + for hit, content in drive_items: + result = {} + if (resource := hit.get("resource")) is not None: + result.update(**serialize_metadata(resource)) + + content = serialize_file_contents(content, result.get("name", "")) + result.update({ + "text": content, + }) + results.append(result) + + if not results: + logger.info(event="[Sharepoint] No documents found.") + return self.get_no_results_error() + + return results diff --git a/src/backend/tools/sharepoint/utils.py b/src/backend/tools/sharepoint/utils.py new file mode 100644 index 0000000000..9d1d1e38fc --- /dev/null +++ b/src/backend/tools/sharepoint/utils.py @@ -0,0 +1,75 @@ +from backend.services.file import ( + CALENDAR_EXTENSION, + CSV_EXTENSION, + DOCX_EXTENSION, + EXCEL_EXTENSION, + EXCEL_OLD_EXTENSION, + JSON_EXTENSION, + MARKDOWN_EXTENSION, + PARQUET_EXTENSION, + PDF_EXTENSION, + TEXT_EXTENSION, + TSV_EXTENSION, + get_file_extension, + read_docx, + read_excel, + read_parquet, +) +from backend.services.utils import read_pdf + + +def serialize_metadata(resource: dict) -> dict: + data = {} + + # Only return primitive types, Coral cannot parse arrays/sub-dictionaries + stripped_resource = { + key: str(value) + for key, value in resource.items() + if isinstance(value, (str, int, bool)) + } + data.update({**stripped_resource}) + + if "name" in resource: + data["title"] = resource["name"] + + if "webUrl" in resource: + data["url"] = resource["webUrl"] + + return data + + +def serialize_file_contents(file_contents: bytes, filename: str) -> str: + """ + Reads the file contents based on the file extension + + Args: + file_contents(bytes): Contents of the file + filename (str): Name of the file + + Returns: + str: The file contents + + Raises: + ValueError: If the file extension is not supported + """ + file_extension = get_file_extension(filename) + + if file_extension == PDF_EXTENSION: + return read_pdf(file_contents) + elif file_extension == DOCX_EXTENSION: + return read_docx(file_contents) + elif file_extension == PARQUET_EXTENSION: + return read_parquet(file_contents) + elif file_extension in [ + TEXT_EXTENSION, + MARKDOWN_EXTENSION, + CSV_EXTENSION, + TSV_EXTENSION, + JSON_EXTENSION, + CALENDAR_EXTENSION + ]: + return file_contents.decode("utf-8") + elif file_extension in [EXCEL_EXTENSION, EXCEL_OLD_EXTENSION]: + return read_excel(file_contents) + + raise ValueError(f"File extension {file_extension} is not supported") diff --git a/src/backend/tools/slack/tool.py b/src/backend/tools/slack/tool.py index 9fde976d0a..eb8b0e1f87 100644 --- a/src/backend/tools/slack/tool.py +++ b/src/backend/tools/slack/tool.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List +from typing import Any from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool @@ -45,7 +46,7 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Returns a list of relevant document snippets from slack.", - ) + ) # type: ignore @classmethod def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: @@ -63,7 +64,9 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: ) raise Exception(message) - async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> List[Dict[str, Any]]: + async def call( + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: user_id = kwargs.get("user_id", "") query = parameters.get("query", "") diff --git a/src/backend/tools/tavily_search.py b/src/backend/tools/tavily_search.py index 24ef7f3f94..d52ce546f5 100644 --- a/src/backend/tools/tavily_search.py +++ b/src/backend/tools/tavily_search.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, List +from typing import Any from tavily import TavilyClient from backend.config.settings import Settings -from backend.database_models.database import DBSessionDep +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool, ToolArgument @@ -37,11 +37,11 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.WebSearch, description="Returns a list of relevant document snippets for a textual query retrieved from the internet.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any - ) -> List[Dict[str, Any]]: + self, parameters: dict, ctx: Context, **kwargs: Any, + ) -> list[dict[str, Any]]: logger = ctx.get_logger() # Gather search parameters query = parameters.get("query", "") diff --git a/src/backend/tools/web_scrape.py b/src/backend/tools/web_scrape.py index c82ae74c9a..288b8347fc 100644 --- a/src/backend/tools/web_scrape.py +++ b/src/backend/tools/web_scrape.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, List +from typing import Any import aiohttp from bs4 import BeautifulSoup +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.services.utils import read_pdf @@ -42,12 +43,12 @@ def get_tool_definition(cls) -> ToolDefinition: error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.", - ) + ) # type: ignore async def call( - self, parameters: dict, ctx: Any, **kwargs: Any - ) -> List[Dict[str, Any]]: - url = parameters.get("url") + self, parameters: dict, ctx: Context, **kwargs: Any + ) -> list[dict[str, Any]]: + url = parameters.get("url", "") async with aiohttp.ClientSession(timeout=ASYNC_TIMEOUT) as session: try: @@ -75,7 +76,7 @@ async def call( }] async def handle_response(self, response: aiohttp.ClientResponse, url: str): - content_type = response.headers.get("content-type") + content_type = response.headers.get("content-type", "") results = [] # If URL is a PDF, read contents using helper function diff --git a/src/community/tools/clinicaltrials.py b/src/community/tools/clinicaltrials.py index e9c6c4fd16..1d3bafdde9 100644 --- a/src/community/tools/clinicaltrials.py +++ b/src/community/tools/clinicaltrials.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any import requests @@ -58,8 +58,8 @@ def get_tool_definition(cls) -> ToolDefinition: ) async def call( - self, parameters: Dict[str, Any], n_max_studies: int = 10, **kwargs - ) -> List[Dict[str, Any]]: + self, parameters: dict[str, Any], **kwargs, + ) -> list[dict[str, Any]]: query_params = {"sort": "LastUpdatePostDate"} if condition := parameters.get("condition", ""): query_params["query.cond"] = condition @@ -69,7 +69,7 @@ async def call( query_params["query.intr"] = intervention if parameters.get("is_recruiting"): query_params["filter.overallStatus"] = "RECRUITING" - query_params["pageSize"] = n_max_studies + query_params["pageSize"] = kwargs.get("n_max_studies", 10) try: response = requests.get(self._url, params=query_params) @@ -85,7 +85,7 @@ async def call( def _parse_response( self, response: requests.Response, location: str, intervention: str - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: data = response.json() return [ self._parse_study(study, location, intervention) @@ -93,8 +93,8 @@ def _parse_response( ] def _parse_study( - self, study: Dict[str, Any], location: str, intervention: str - ) -> Dict[str, Any]: + self, study: dict[str, Any], location: str, intervention: str + ) -> dict[str, Any]: """Parse individual study data.""" id_module = study["protocolSection"].get("identificationModule", {}) description_module = study["protocolSection"].get("descriptionModule", {}) @@ -126,8 +126,8 @@ def _parse_study( } def _filter_results( - self, results: List[Dict[str, Any]], target: str, fields: List[str] - ) -> List[Dict[str, Any]]: + self, results: list[dict[str, Any]], target: str, fields: list[str] + ) -> list[dict[str, Any]]: """Keeps a result if any of the specified fields contain the target value. Only the specified fields are returned in the result. If the query does not specify a target value, all results are retained. diff --git a/src/community/tools/llama_index.py b/src/community/tools/llama_index.py index d2dfa2f601..56e83d522e 100644 --- a/src/community/tools/llama_index.py +++ b/src/community/tools/llama_index.py @@ -7,6 +7,7 @@ import backend.crud.file as file_crud from backend.config import Settings +from backend.schemas.context import Context from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -74,7 +75,7 @@ def get_tool_definition(cls) -> ToolDefinition: ) async def call( - self, parameters: dict, ctx: Any, **kwargs: Any + self, parameters: dict, ctx: Context, **kwargs: Any ) -> List[Dict[str, Any]]: query = parameters.get("query") files = parameters.get("files") diff --git a/src/community/tools/wolfram.py b/src/community/tools/wolfram.py index dc4c27e22a..3383d3921b 100644 --- a/src/community/tools/wolfram.py +++ b/src/community/tools/wolfram.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper @@ -39,7 +39,7 @@ def get_tool_definition(cls) -> ToolDefinition: description="Evaluate arithmetic expressions using Wolfram Alpha.", ) - async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + async def call(self, parameters: dict, **kwargs: Any) -> list[dict[str, Any]]: to_evaluate = parameters.get("expression", "") try: result = self.tool.run(to_evaluate) @@ -49,4 +49,4 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: if not result: return self.get_no_results_error() - return {"result": result, "text": result} + return [{"result": result, "text": result}] diff --git a/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx b/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx index b2acc46172..595decd146 100644 --- a/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx @@ -14,7 +14,7 @@ import { Tabs, Text, } from '@/components/UI'; -import { TOOL_GITHUB_ID, TOOL_GMAIL_ID, TOOL_SLACK_ID } from '@/constants'; +import { TOOL_GITHUB_ID, TOOL_GMAIL_ID, TOOL_SHAREPOINT_ID, TOOL_SLACK_ID } from '@/constants'; import { useDeleteAuthTool, useListTools, useNotify } from '@/hooks'; import { cn, getToolAuthUrl } from '@/utils'; @@ -25,7 +25,7 @@ const tabs: { key: string; icon: IconName; label: string }[] = [ { key: 'profile', icon: 'profile', label: 'Profile' }, ]; -const Settings = () => { +export const Settings: React.FC = () => { const [selectedTabIndex, setSelectedTabIndex] = useState(0); return ( @@ -83,6 +83,7 @@ const Connections = () => ( + ); @@ -234,4 +235,12 @@ const GithubConnection = () => ( /> ); -export { Settings }; +const SharepointConnection = () => ( + +); diff --git a/src/interfaces/assistants_web/src/assets/icons/Sharepoint.tsx b/src/interfaces/assistants_web/src/assets/icons/Sharepoint.tsx new file mode 100644 index 0000000000..a958ea59f0 --- /dev/null +++ b/src/interfaces/assistants_web/src/assets/icons/Sharepoint.tsx @@ -0,0 +1,55 @@ +import * as React from 'react'; +import { SVGProps } from 'react'; + +import { cn } from '@/utils'; + +export const Sharepoint: React.FC> = ({ className, ...props }) => ( + + + + + + + + + + + + + + + + +); diff --git a/src/interfaces/assistants_web/src/assets/icons/index.ts b/src/interfaces/assistants_web/src/assets/icons/index.ts index 33584159a3..acd5354184 100644 --- a/src/interfaces/assistants_web/src/assets/icons/index.ts +++ b/src/interfaces/assistants_web/src/assets/icons/index.ts @@ -60,3 +60,4 @@ export * from './Warning'; export * from './Web'; export * from './Slack'; export * from './Gmail'; +export * from './Sharepoint'; diff --git a/src/interfaces/assistants_web/src/components/UI/Icon.tsx b/src/interfaces/assistants_web/src/components/UI/Icon.tsx index 1a00622dc8..2afee3da07 100644 --- a/src/interfaces/assistants_web/src/components/UI/Icon.tsx +++ b/src/interfaces/assistants_web/src/components/UI/Icon.tsx @@ -47,6 +47,7 @@ import { Search, Setttings, Share, + Sharepoint, Show, SignOut, Slack, @@ -130,6 +131,7 @@ export const IconList = [ 'slack', 'gmail', 'github', + 'sharepoint', ] as const; export type IconName = (typeof IconList)[number]; @@ -489,6 +491,11 @@ const getIcon = (name: IconName, kind: IconKind): React.ReactNode => { ), + ['sharepoint']: ( + + + + ), ['hot-keys']: ( diff --git a/src/interfaces/assistants_web/src/constants/tools.ts b/src/interfaces/assistants_web/src/constants/tools.ts index 029433bdd2..e4bb316ead 100644 --- a/src/interfaces/assistants_web/src/constants/tools.ts +++ b/src/interfaces/assistants_web/src/constants/tools.ts @@ -15,6 +15,7 @@ export const TOOL_GOOGLE_DRIVE_ID = 'google_drive'; export const TOOL_SLACK_ID = 'slack'; export const TOOL_GMAIL_ID = 'gmail'; export const TOOL_GITHUB_ID = 'github'; +export const TOOL_SHAREPOINT_ID = 'sharepoint'; export const BACKGROUND_TOOLS = [TOOL_SEARCH_FILE_ID, TOOL_READ_DOCUMENT_ID]; @@ -32,4 +33,5 @@ export const TOOL_ID_TO_DISPLAY_INFO: { [id: string]: { icon: IconName } } = { [TOOL_SLACK_ID]: { icon: 'slack' }, [TOOL_GMAIL_ID]: { icon: 'gmail' }, [TOOL_GITHUB_ID]: { icon: 'github' }, + [TOOL_SHAREPOINT_ID]: { icon: 'sharepoint' }, };