diff --git a/Makefile b/Makefile index 069b3a16..4854e506 100644 --- a/Makefile +++ b/Makefile @@ -28,15 +28,22 @@ check: ## Run code quality tools. .PHONY: test test: ## Test the code with pytest @echo "🚀 Testing code: Running pytest" - @cd arcade && poetry run pytest -v --cov --cov-config=pyproject.toml --cov-report=xml + @cd arcade && poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml .PHONY: test-toolkits test-toolkits: ## Iterate over all toolkits and run pytest on each one @echo "🚀 Testing code in toolkits: Running pytest" @for dir in toolkits/*/ ; do \ - (cd $$dir && poetry run pytest -v --cov --cov-config=pyproject.toml --cov-report=xml || exit 1); \ + (cd $$dir && poetry run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml || exit 1); \ done +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @cd arcade && coverage report + @echo "Generating coverage report" + @cd arcade && coverage html + .PHONY: set-version set-version: ## Set the version in the pyproject.toml file @echo "🚀 Setting version in pyproject.toml" diff --git a/arcade/arcade/cli/display.py b/arcade/arcade/cli/display.py new file mode 100644 index 00000000..5cb55b81 --- /dev/null +++ b/arcade/arcade/cli/display.py @@ -0,0 +1,263 @@ +from typing import TYPE_CHECKING, Any + +from rich.console import Console +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from arcade.core.config_model import Config +from arcade.core.schema import ToolDefinition + +if TYPE_CHECKING: + from arcade.sdk.eval.eval import EvaluationResult +console = Console() + + +def display_tools_table(tools: list[ToolDefinition]) -> None: + """ + Display a table of tools with their name, description, package, and version. + """ + table = Table(show_header=True, header_style="bold magenta") + table.add_column("Name") + table.add_column("Description") + table.add_column("Package") + table.add_column("Version") + + for tool in sorted(tools, key=lambda x: x.toolkit.name): + table.add_row( + str(tool.get_fully_qualified_name()), + tool.description.split("\n")[0] if tool.description else "", + tool.toolkit.name, + tool.toolkit.version, + ) + console.print(table) + + +def display_tool_details(tool: ToolDefinition) -> None: + """ + Display detailed information about a specific tool using multiple panels. 
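+
+    The output is three stacked panels: description, input parameters, and
+    expected output. A minimal illustrative call, assuming `tools` holds
+    ToolDefinitions:
+
+        >>> display_tool_details(tools[0])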
+ """ + # Description Panel + description_panel = Panel( + tool.description or "No description available.", + title=f"Tool: {tool.name}", + border_style="cyan", + ) + + # Inputs Panel + inputs = tool.inputs.parameters + if inputs: + inputs_table = Table(show_header=True, header_style="bold green") + inputs_table.add_column("Name", style="cyan") + inputs_table.add_column("Type", style="magenta") + inputs_table.add_column("Required", style="yellow") + inputs_table.add_column("Description", style="white") + inputs_table.add_column("Default", style="blue") + for param in inputs: + # Since InputParameter does not have a default field, we use "N/A" + default_value = "N/A" + if param.value_schema.enum: + default_value = f"One of {param.value_schema.enum}" + inputs_table.add_row( + param.name, + param.value_schema.val_type, + str(param.required), + param.description or "", + default_value, + ) + inputs_panel = Panel( + inputs_table, + title="Input Parameters", + border_style="green", + ) + else: + inputs_panel = Panel( + "No input parameters.", + title="Input Parameters", + border_style="green", + ) + + # Output Panel + output = tool.output + if output: + output_description = output.description or "No description available." + output_types = ", ".join(output.available_modes) + output_val_type = output.value_schema.val_type if output.value_schema else "N/A" + output_details = Text.assemble( + ("Description: ", "bold"), + (output_description, ""), + "\n", + ("Available Modes: ", "bold"), + (output_types, ""), + "\n", + ("Value Type: ", "bold"), + (output_val_type, ""), + ) + output_panel = Panel( + output_details, + title="Expected Output", + border_style="blue", + ) + else: + output_panel = Panel( + "No output information available.", + title="Expected Output", + border_style="blue", + ) + + # Combine all panels vertically + console.print(description_panel) + console.print(inputs_panel) + console.print(output_panel) + + +def display_tool_messages(tool_messages: list[dict]) -> None: + for message in tool_messages: + if message["role"] == "assistant": + for tool_call in message.get("tool_calls", []): + console.print( + f"[bright_black][bold]Called tool '{tool_call['function']['name']}'[/bold]\n[bold]Parameters:[/bold]{tool_call['function']['arguments']}[/bright_black]" + ) + elif message["role"] == "tool": + console.print( + f"[bright_black][bold]'{message['name']}' tool returned:[/bold]{message['content']}[/bright_black]" + ) + + +def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None: + """ + Display evaluation results in a format inspired by pytest's output. + + Args: + results: List of dictionaries containing evaluation results for each model. + show_details: Whether to show detailed results for each case. 
+ """ + total_passed = 0 + total_failed = 0 + total_warned = 0 + total_cases = 0 + + for eval_suite in results: + for model_results in eval_suite: + model = model_results.get("model", "Unknown Model") + rubric = model_results.get("rubric", "Unknown Rubric") + cases = model_results.get("cases", []) + total_cases += len(cases) + + console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]") + if show_details: + console.print(f"[bold magenta]{rubric}[/bold magenta]") + + for case in cases: + evaluation = case["evaluation"] + status = ( + "[green]PASSED[/green]" + if evaluation.passed + else "[yellow]WARNED[/yellow]" + if evaluation.warning + else "[red]FAILED[/red]" + ) + if evaluation.passed: + total_passed += 1 + elif evaluation.warning: + total_warned += 1 + else: + total_failed += 1 + + # Display one-line summary for each case with score as a percentage + score_percentage = evaluation.score * 100 + console.print(f"{status} {case['name']} -- Score: {score_percentage:.2f}%") + + if show_details: + # Show detailed information for each case + console.print(f"[bold]User Input:[/bold] {case['input']}\n") + console.print("[bold]Details:[/bold]") + console.print(_format_evaluation(evaluation)) + console.print("-" * 80) + + # Summary + summary = ( + f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]" + ) + if total_warned > 0: + summary += f" -- [yellow]Warnings: {total_warned}[/yellow]" + if total_failed > 0: + summary += f" -- [red]Failed: {total_failed}[/red]" + console.print(summary + "\n") + + +def _format_evaluation(evaluation: "EvaluationResult") -> str: + """ + Format evaluation results with color-coded matches and scores. + + Args: + evaluation: An EvaluationResult object containing the evaluation results. + + Returns: + A formatted string representation of the evaluation details. + """ + result_lines = [] + if evaluation.failure_reason: + result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}") + else: + for critic_result in evaluation.results: + match_color = "green" if critic_result["match"] else "red" + field = critic_result["field"] + score = critic_result["score"] + weight = critic_result["weight"] + expected = critic_result["expected"] + actual = critic_result["actual"] + + result_lines.append( + f"[bold]{field}:[/bold] " + f"[{match_color}]Match: {critic_result['match']}" + f"\n Score: {score:.2f}/{weight:.2f}[/{match_color}]" + f"\n Expected: {expected}" + f"\n Actual: {actual}" + ) + return "\n".join(result_lines) + + +def display_arcade_chat_header(config: Config, stream: bool) -> None: + chat_header = Text.assemble( + "\n", + ( + "=== Arcade AI Chat ===", + "bold magenta underline", + ), + "\n", + "\n", + "Chatting with Arcade Engine at ", + ( + config.engine_url, + "bold blue", + ), + ) + if stream: + chat_header.append(" (streaming)") + console.print(chat_header) + + +def display_config_as_table(config) -> None: # type: ignore[no-untyped-def] + """ + Display the configuration details as a table using Rich library. 
+ """ + table = Table(show_header=True, header_style="bold magenta") + table.add_column("Section") + table.add_column("Name") + table.add_column("Value") + + for section_name in config.model_dump(): + section = getattr(config, section_name) + if section: + section = section.dict() + first = True + for name, value in section.items(): + if first: + table.add_row(section_name, name, str(value)) + first = False + else: + table.add_row("", name, str(value)) + table.add_row("", "", "") + + console.print(table) diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index d6b609d4..98ed1523 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -11,27 +11,32 @@ from openai import OpenAIError from rich.console import Console from rich.markup import escape -from rich.table import Table from rich.text import Text from arcade.cli.authn import LocalAuthCallbackServer, check_existing_login +from arcade.cli.display import ( + display_arcade_chat_header, + display_config_as_table, + display_eval_results, + display_tool_details, + display_tool_messages, + display_tools_table, +) from arcade.cli.launcher import start_servers from arcade.cli.utils import ( OrderCommands, - apply_config_overrides, create_cli_catalog, - display_eval_results, - display_tool_messages, + get_config_with_overrides, get_eval_files, + get_tools_from_engine, handle_chat_interaction, is_authorization_pending, - load_eval_suites, # Import the new function + load_eval_suites, + log_engine_health, validate_and_get_config, wait_for_authorization_completion, ) from arcade.client import Arcade -from arcade.client.errors import EngineNotHealthyError, EngineOfflineError -from arcade.core.config_model import Config cli = typer.Typer( cls=OrderCommands, @@ -44,27 +49,6 @@ console = Console() -def _get_config_with_overrides( - force_tls: bool, - force_no_tls: bool, - host_input: str | None = None, - port_input: int | None = None, -) -> Config: - """ - Get the config with CLI-specific optional overrides applied. - """ - config = validate_and_get_config() - - if not force_tls and not force_no_tls: - tls_input = None - elif force_no_tls: - tls_input = False - else: - tls_input = True - apply_config_overrides(config, host_input, port_input, tls_input) - return config - - @cli.command(help="Log in to Arcade Cloud", rich_help_panel="User") def login( host: str = typer.Option( @@ -136,44 +120,75 @@ def new( @cli.command( - help="Show the installed toolkits", + help="Show the installed toolkits or details of a specific tool", rich_help_panel="Tool Development", ) def show( toolkit: Optional[str] = typer.Option( None, "-t", "--toolkit", help="The toolkit to show the tools of" ), + tool: Optional[str] = typer.Option( + None, "-T", "--tool", help="The specific tool to show details for" + ), + host: Optional[str] = typer.Option( + None, + "-h", + "--host", + help="The Arcade Engine address to send chat requests to.", + ), + port: Optional[int] = typer.Option( + None, + "-p", + "--port", + help="The port of the Arcade Engine.", + ), + force_tls: bool = typer.Option( + False, + "--tls", + help="Whether to force TLS for the connection to the Arcade Engine. 
If not specified, the connection will use TLS if the engine URL uses a 'https' scheme.", + ), + force_no_tls: bool = typer.Option( + False, + "--no-tls", + help="Whether to disable TLS for the connection to the Arcade Engine.", + ), debug: bool = typer.Option(False, "--debug", "-d", help="Show debug information"), ) -> None: """ - Show the available tools in an actor or toolkit + Show the available toolkits or detailed information about a specific tool. """ try: - catalog = create_cli_catalog(toolkit=toolkit) - - # Create a table with Rich library - table = Table(show_header=True, header_style="bold magenta") - table.add_column("Name") - table.add_column("Description") - table.add_column("Package") - table.add_column("Version") - - tool_names = catalog.get_tool_names() - for tool_name in tool_names: - tool = catalog.get_tool(tool_name) - package = tool.meta.package if tool.meta.package else tool.meta.toolkit - table.add_row(str(tool_name), tool.description, package, tool.version) - - console.print(table) - - # used when debugging a broken package on import. - # `arcade show` is the first command used after - # a toolkit package is created. + if not host: + catalog = create_cli_catalog(toolkit=toolkit) + tools = [t.definition for t in list(catalog)] + else: + tools = get_tools_from_engine(host, port, force_tls, force_no_tls, toolkit) + + if tool: + # Display detailed information for the specified tool + tool_def = next( + ( + t + for t in tools + if t.get_fully_qualified_name().name == tool + or str(t.get_fully_qualified_name()) == tool + ), + None, + ) + if not tool_def: + console.print(f"❌ Tool '{tool}' not found.", style="bold red") + typer.Exit(code=1) + else: + display_tool_details(tool_def) + else: + # Display the list of tools as a table + display_tools_table(tools) + except Exception as e: if debug: raise - error_message = f"❌ Failed to List tools: {escape(str(e))}" + error_message = f"❌ Failed to list tools: {escape(str(e))}" console.print(error_message, style="bold red") @@ -211,7 +226,7 @@ def chat( """ Chat with a language model. """ - config = _get_config_with_overrides(force_tls, force_no_tls, host, port) + config = get_config_with_overrides(force_tls, force_no_tls, host, port) client = Arcade(api_key=config.api.key, base_url=config.engine_url) user_email = config.user.email if config.user else None @@ -321,73 +336,6 @@ def config( raise typer.Exit(code=1) -def display_arcade_chat_header(config: Config, stream: bool) -> None: - chat_header = Text.assemble( - "\n", - ( - "=== Arcade AI Chat ===", - "bold magenta underline", - ), - "\n", - "\n", - "Chatting with Arcade Engine at ", - ( - config.engine_url, - "bold blue", - ), - ) - if stream: - chat_header.append(" (streaming)") - console.print(chat_header) - - -def log_engine_health(client: Arcade) -> None: - try: - client.health.check() - - except EngineNotHealthyError as e: - console.print( - "[bold][yellow]⚠️ Warning: " - + str(e) - + " (" - + "[/yellow]" - + "[red]" - + str(e.status_code) - + "[/red]" - + "[yellow])[/yellow][/bold]" - ) - except EngineOfflineError: - console.print( - "⚠️ Warning: Arcade Engine was unreachable. (Is it running?)", - style="bold yellow", - ) - - -def display_config_as_table(config) -> None: # type: ignore[no-untyped-def] - """ - Display the configuration details as a table using Rich library. 
- """ - table = Table(show_header=True, header_style="bold magenta") - table.add_column("Section") - table.add_column("Name") - table.add_column("Value") - - for section_name in config.model_dump(): - section = getattr(config, section_name) - if section: - section = section.dict() - first = True - for name, value in section.items(): - if first: - table.add_row(section_name, name, str(value)) - first = False - else: - table.add_row("", name, str(value)) - table.add_row("", "", "") - - console.print(table) - - @cli.command(help="Run tool calling evaluations", rich_help_panel="Tool Development") def evals( directory: str = typer.Argument(".", help="Directory containing evaluation files"), @@ -428,7 +376,7 @@ def evals( Find all files starting with 'eval_' in the given directory, execute any functions decorated with @tool_eval, and display the results. """ - config = _get_config_with_overrides(force_tls, force_no_tls, host, port) + config = get_config_with_overrides(force_tls, force_no_tls, host, port) models_list = models.split(",") # Use 'models_list' to avoid shadowing diff --git a/arcade/arcade/cli/utils.py b/arcade/arcade/cli/utils.py index d2bf72d3..def7ef32 100644 --- a/arcade/arcade/cli/utils.py +++ b/arcade/arcade/cli/utils.py @@ -1,7 +1,7 @@ import importlib.util from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import Callable, Union import typer from openai.resources.chat.completions import ChatCompletionChunk, Stream @@ -13,16 +13,14 @@ from typer.models import Context from arcade.client.client import Arcade -from arcade.client.errors import APITimeoutError +from arcade.client.errors import APITimeoutError, EngineNotHealthyError, EngineOfflineError from arcade.client.schema import AuthResponse from arcade.core.catalog import ToolCatalog from arcade.core.config_model import Config from arcade.core.errors import ToolkitLoadError +from arcade.core.schema import ToolDefinition from arcade.core.toolkit import Toolkit -if TYPE_CHECKING: - from arcade.sdk.eval.eval import EvaluationResult - console = Console() @@ -64,17 +62,37 @@ def create_cli_catalog( return catalog -def display_tool_messages(tool_messages: list[dict]) -> None: - for message in tool_messages: - if message["role"] == "assistant": - for tool_call in message.get("tool_calls", []): - console.print( - f"[bright_black][bold]Called tool '{tool_call['function']['name']}'[/bold]\n[bold]Parameters:[/bold]{tool_call['function']['arguments']}[/bright_black]" - ) - elif message["role"] == "tool": - console.print( - f"[bright_black][bold]'{message['name']}' tool returned:[/bold]{message['content']}[/bright_black]" - ) +def get_config_with_overrides( + force_tls: bool, + force_no_tls: bool, + host_input: str | None = None, + port_input: int | None = None, +) -> Config: + """ + Get the config with CLI-specific optional overrides applied. 
+ """ + config = validate_and_get_config() + + if not force_tls and not force_no_tls: + tls_input = None + elif force_no_tls: + tls_input = False + else: + tls_input = True + apply_config_overrides(config, host_input, port_input, tls_input) + return config + + +def get_tools_from_engine( + host: str, + port: int | None = None, + force_tls: bool = False, + force_no_tls: bool = False, + toolkit: str | None = None, +) -> list[ToolDefinition]: + config = get_config_with_overrides(force_tls, force_no_tls, host, port) + client = Arcade(api_key=config.api.key, base_url=config.engine_url) + return client.tools.list_tools(toolkit=toolkit) def get_tool_messages(choice: dict) -> list[dict]: @@ -199,96 +217,26 @@ def apply_config_overrides( config.engine.tls = tls_input -def display_eval_results(results: list[list[dict[str, Any]]], show_details: bool = False) -> None: - """ - Display evaluation results in a format inspired by pytest's output. +def log_engine_health(client: Arcade) -> None: + try: + client.health.check() - Args: - results: List of dictionaries containing evaluation results for each model. - show_details: Whether to show detailed results for each case. - """ - total_passed = 0 - total_failed = 0 - total_warned = 0 - total_cases = 0 - - for eval_suite in results: - for model_results in eval_suite: - model = model_results.get("model", "Unknown Model") - rubric = model_results.get("rubric", "Unknown Rubric") - cases = model_results.get("cases", []) - total_cases += len(cases) - - console.print(f"[bold]Model:[/bold] [bold magenta]{model}[/bold magenta]") - if show_details: - console.print(f"[bold magenta]{rubric}[/bold magenta]") - - for case in cases: - evaluation = case["evaluation"] - status = ( - "[green]PASSED[/green]" - if evaluation.passed - else "[yellow]WARNED[/yellow]" - if evaluation.warning - else "[red]FAILED[/red]" - ) - if evaluation.passed: - total_passed += 1 - elif evaluation.warning: - total_warned += 1 - else: - total_failed += 1 - - # Display one-line summary for each case - console.print(f"{status} {case['name']} -- Score: {evaluation.score:.2f}") - - if show_details: - # Show detailed information for each case - console.print(f"[bold]User Input:[/bold] {case['input']}\n") - console.print("[bold]Details:[/bold]") - console.print(_format_evaluation(evaluation)) - console.print("-" * 80) - - # Summary - summary = ( - f"[bold]Summary -- [/bold]Total: {total_cases} -- [green]Passed: {total_passed}[/green]" - ) - if total_warned > 0: - summary += f" -- [yellow]Warnings: {total_warned}[/yellow]" - if total_failed > 0: - summary += f" -- [red]Failed: {total_failed}[/red]" - console.print(summary + "\n") - - -def _format_evaluation(evaluation: "EvaluationResult") -> str: - """ - Format evaluation results with color-coded matches and scores. - - Args: - evaluation: An EvaluationResult object containing the evaluation results. - - Returns: - A formatted string representation of the evaluation details. 
- """ - result_lines = [] - if evaluation.failure_reason: - result_lines.append(f"[bold red]Failure Reason:[/bold red] {evaluation.failure_reason}") - else: - for critic_result in evaluation.results: - match_color = "green" if critic_result["match"] else "red" - field = critic_result["field"] - score = critic_result["score"] - weight = critic_result["weight"] - expected = critic_result["expected"] - actual = critic_result["actual"] - result_lines.append( - f"[bold]{field}:[/bold] " - f"[{match_color}]Match: {critic_result['match']}, " - f"Score: {score:.2f}/{weight:.2f}[/{match_color}]" - f"\n Expected: {expected}" - f"\n Actual: {actual}" - ) - return "\n".join(result_lines) + except EngineNotHealthyError as e: + console.print( + "[bold][yellow]⚠️ Warning: " + + str(e) + + " (" + + "[/yellow]" + + "[red]" + + str(e.status_code) + + "[/red]" + + "[yellow])[/yellow][/bold]" + ) + except EngineOfflineError: + console.print( + "⚠️ Warning: Arcade Engine was unreachable. (Is it running?)", + style="bold yellow", + ) @dataclass diff --git a/arcade/arcade/client/client.py b/arcade/arcade/client/client.py index 123cb8d7..612f6286 100644 --- a/arcade/arcade/client/client.py +++ b/arcade/arcade/client/client.py @@ -166,6 +166,17 @@ def authorize( ) return AuthResponse(**data) + def list_tools(self, toolkit: str | None = None) -> list[ToolDefinition]: + """ + List the tools available for a given toolkit and provider. + """ + data = self._client._execute_request( # type: ignore[attr-defined] + "GET", + f"{self._resource_path}/list", + params={"toolkit": toolkit}, + ) + return [ToolDefinition(**tool) for tool in data] + class HealthResource(BaseResource[ClientT]): """Health check resource.""" @@ -331,6 +342,17 @@ async def authorize( ) return AuthResponse(**data) + async def list_tools(self, toolkit: str | None = None) -> list[ToolDefinition]: + """ + List the tools available for a given toolkit and provider. + """ + data = await self._client._execute_request( # type: ignore[attr-defined] + "GET", + f"{self._resource_path}/list", + params={"toolkit": toolkit}, + ) + return [ToolDefinition(**tool) for tool in data] + class AsyncHealthResource(BaseResource[AsyncArcadeClient]): """Asynchronous Health check resource.""" diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index a4a92397..a9451a4b 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -25,6 +25,7 @@ from arcade.core.errors import ToolDefinitionError from arcade.core.schema import ( + TOOL_NAME_SEPARATOR, FullyQualifiedName, InputParameter, OAuth2Requirement, @@ -215,6 +216,45 @@ def find_tool_by_func(self, func: Callable) -> ToolDefinition: return tool.definition raise ValueError(f"Tool {func} not found in the catalog.") + def get_tool_by_name( + self, name: str, version: Optional[str] = None, separator: str = TOOL_NAME_SEPARATOR + ) -> MaterializedTool: + """ + Get a tool from the catalog by name, optionally including the toolkit name. + + Args: + name: The name of the tool, potentially including the toolkit name separated by the `separator`. + version: The version of the toolkit. Defaults to None. + separator: The separator between toolkit and tool names. Defaults to `TOOL_NAME_SEPARATOR`. + + Returns: + MaterializedTool: The matching tool from the catalog. + + Raises: + ValueError: If the tool is not found in the catalog. 
+ """ + if separator in name: + toolkit_name, tool_name = name.split(separator, 1) + fq_name = FullyQualifiedName( + name=tool_name, toolkit_name=toolkit_name, toolkit_version=version + ) + return self.get_tool(fq_name) + else: + # No toolkit name provided, search tools with matching tool name + matching_tools = [ + tool + for fq_name, tool in self._tools.items() + if fq_name.name.lower() == name.lower() + and ( + version is None + or (fq_name.toolkit_version or "").lower() == (version or "").lower() + ) + ] + if matching_tools: + return matching_tools[0] + + raise ValueError(f"Tool {name} not found in the catalog.") + def get_tool(self, name: FullyQualifiedName) -> MaterializedTool: """ Get a tool from the catalog by fully-qualified name and version. diff --git a/arcade/arcade/sdk/eval/__init__.py b/arcade/arcade/sdk/eval/__init__.py index b5a686af..064e74ef 100644 --- a/arcade/arcade/sdk/eval/__init__.py +++ b/arcade/arcade/sdk/eval/__init__.py @@ -1,10 +1,11 @@ -from .critic import BinaryCritic, NumericCritic, SimilarityCritic +from .critic import BinaryCritic, DatetimeCritic, NumericCritic, SimilarityCritic from .eval import EvalRubric, EvalSuite, ExpectedToolCall, tool_eval __all__ = [ "BinaryCritic", "SimilarityCritic", "NumericCritic", + "DatetimeCritic", "EvalRubric", "EvalSuite", "ExpectedToolCall", diff --git a/arcade/arcade/sdk/eval/critic.py b/arcade/arcade/sdk/eval/critic.py index 8546b4eb..dcb66e36 100644 --- a/arcade/arcade/sdk/eval/critic.py +++ b/arcade/arcade/sdk/eval/critic.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from datetime import timedelta from typing import Any, ClassVar +import pytz +from dateutil import parser + from arcade.sdk.error import WeightError @@ -47,6 +51,17 @@ def cast_actual(self, expected: Any, actual: Any) -> Any: Raises: TypeError: If the casting is not possible. """ + # In case both are strings. + if actual == "None": + actual = None + if expected == "None": + expected = None + if expected is None: + # No need to cast; return actual as is + return actual + if actual is None: + # No need to cast; return None + return None expected_type = type(expected) try: return expected_type(actual) @@ -60,14 +75,18 @@ def evaluate(self, expected: Any, actual: Any) -> dict[str, float | bool]: Evaluates whether the expected and actual values are exactly equal after casting. Args: - expected (Any): The expected value. - actual (Any): The actual value to compare, cast to the type of expected. + expected: The expected value. + actual: The actual value to compare, cast to the type of expected. Returns: - dict[str, float | bool]: A dictionary containing the match status and score. + dict: A dictionary containing the match status and score. """ # Cast actual to the type of expected - actual_casted = self.cast_actual(expected, actual) + try: + actual_casted = self.cast_actual(expected, actual) + # TODO log or something better here + except TypeError: + actual_casted = actual match = expected == actual_casted return {"match": match, "score": self.weight if match else 0.0} @@ -187,3 +206,68 @@ def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]: "match": similarity >= self.similarity_threshold, "score": min(similarity * self.weight, self.weight), } + + +@dataclass +@dataclass +class DatetimeCritic(Critic): + """ + A critic that evaluates the closeness of datetime values within a specified tolerance. + + Attributes: + tolerance: Acceptable timedelta between expected and actual datetimes. 
+ max_difference: Maximum timedelta for a partial score. + """ + + critic_field: str + weight: float + tolerance: timedelta = timedelta(seconds=500) + max_difference: timedelta = timedelta(hours=2) + + def evaluate(self, expected: str, actual: str) -> dict[str, float | bool]: + """Evaluates the closeness of datetime values within a specified tolerance.""" + + # Attempt to parse expected and actual datetime strings + try: + expected_dt = parser.parse(expected) + actual_dt = parser.parse(actual) + except (ValueError, TypeError): + # If parsing fails, return score 0 + return {"match": False, "score": 0.0} + + # Handle cases based on presence of tzinfo + if expected_dt.tzinfo is None and actual_dt.tzinfo is None: + # Both datetimes are naive, compare directly + time_diff_seconds = abs((expected_dt - actual_dt).total_seconds()) + elif expected_dt.tzinfo is not None and actual_dt.tzinfo is not None: + # Both datetimes have tzinfo, compare in UTC + expected_utc = expected_dt.astimezone(pytz.utc) + actual_utc = actual_dt.astimezone(pytz.utc) + time_diff_seconds = abs((expected_utc - actual_utc).total_seconds()) + else: + # One datetime has tzinfo and the other doesn't + # Compare naive datetime with the other's naive equivalent + if expected_dt.tzinfo is not None: + expected_naive = expected_dt.replace(tzinfo=None) + time_diff_seconds = abs((expected_naive - actual_dt).total_seconds()) + else: + actual_naive = actual_dt.replace(tzinfo=None) + time_diff_seconds = abs((expected_dt - actual_naive).total_seconds()) + + # Convert tolerances to seconds + tolerance_seconds = self.tolerance.total_seconds() + max_difference_seconds = self.max_difference.total_seconds() + + if time_diff_seconds <= tolerance_seconds: + # Full score if within tolerance + return {"match": True, "score": self.weight} + elif time_diff_seconds >= max_difference_seconds: + # No score if beyond max_difference + return {"match": False, "score": 0.0} + else: + # Partial score based on time difference + ratio = 1 - (time_diff_seconds / max_difference_seconds) + # Ensure ratio is not negative + ratio = max(ratio, 0) + score = self.weight * ratio + return {"match": False, "score": score} diff --git a/arcade/arcade/sdk/eval/eval.py b/arcade/arcade/sdk/eval/eval.py index ab6d5c56..ae51a89c 100644 --- a/arcade/arcade/sdk/eval/eval.py +++ b/arcade/arcade/sdk/eval/eval.py @@ -1,11 +1,12 @@ import asyncio import functools +import inspect import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable from arcade.core.config_model import Config -from arcade.core.schema import FullyQualifiedName +from arcade.core.schema import TOOL_NAME_SEPARATOR try: import numpy as np @@ -218,7 +219,10 @@ def check_tool_call_quantity_failure(self, actual_count: int) -> bool: expected_count = len(self.expected_tool_calls) return self.rubric.fail_on_tool_call_quantity and expected_count != actual_count - def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> EvaluationResult: + def evaluate( + self, + actual_tool_calls: list[tuple[str, dict[str, Any]]], + ) -> EvaluationResult: """ Evaluate the actual tool calls against the expected tool calls and critics. @@ -229,13 +233,13 @@ def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> Evalu An EvaluationResult object containing the evaluation results. 
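+
+        Evaluation proceeds in order: tool-call quantity, tool selection, then
+        per-argument critics, with expected and actual calls paired via linear
+        sum assignment over a score matrix.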
""" evaluation_result = EvaluationResult() - actual_tools = [tool for tool, _ in actual_tool_calls] + actual_tools = [tool_name for tool_name, _ in actual_tool_calls] actual_count = len(actual_tool_calls) + if self.check_tool_call_quantity_failure(actual_count): evaluation_result.score = 0.0 evaluation_result.passed = False - evaluation_result.warning = False expected_count = len(self.expected_tool_calls) expected_tool_names = ", ".join( tool_call.name for tool_call in self.expected_tool_calls @@ -246,35 +250,27 @@ def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> Evalu ) return evaluation_result - # check if no tools should be called and none were called if not self.expected_tool_calls and not actual_tools: evaluation_result.score = 1.0 evaluation_result.passed = True - evaluation_result.warning = False return evaluation_result if self.check_tool_selection_failure(actual_tools): evaluation_result.score = 0.0 evaluation_result.passed = False - evaluation_result.warning = False expected_tools = [tc.name for tc in self.expected_tool_calls] evaluation_result.failure_reason = f"Tool selection mismatch. Expected tools: {expected_tools}, but got: {actual_tools}" return evaluation_result - # if no critics for tool call arguments, then return - # passing score as only tool selection and quantity is checked - if not self.critics or len(self.critics) == 0: + if not self.critics: evaluation_result.score = 1.0 evaluation_result.passed = True - evaluation_result.warning = False - # TODO passing reason should be added return evaluation_result # Create a cost matrix for the assignment problem - cost_matrix = self._create_cost_matrix(actual_tool_calls) + cost_matrix = self._create_cost_matrix(actual_tool_calls, self.expected_tool_calls) - # Use the Linear Sum Assignment (LSA) algorithm to find the optimal assignment - # The algorithm maximizes the total score of the assignment + # Use the Linear Sum Assignment algorithm to find the optimal assignment row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True) total_score = 0.0 @@ -283,10 +279,11 @@ def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> Evalu for i, j in zip(row_ind, col_ind): if i < len(self.expected_tool_calls) and j < len(actual_tool_calls): expected = self.expected_tool_calls[i] - actual_tool, actual_args = actual_tool_calls[j] + actual_name, actual_args = actual_tool_calls[j] + # Tool selection tool_selection_score = evaluation_result.score_tool_selection( - expected.name, actual_tool, self.rubric.tool_selection_weight + expected.name, actual_name, self.rubric.tool_selection_weight ) total_score += tool_selection_score total_weight += self.rubric.tool_selection_weight @@ -295,32 +292,35 @@ def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> Evalu for critic in self.critics: expected_value = expected.args.get(critic.critic_field) actual_value = actual_args.get(critic.critic_field) - if expected_value is not None and actual_value is not None: - try: - result = critic.evaluate(expected_value, actual_value) - total_score += result["score"] - total_weight += critic.weight - evaluation_result.add( - critic.critic_field, - result, - critic.weight, - expected_value, - actual_value, - ) - except Exception as e: - print( - f"Critic evaluation failed for field '{critic.critic_field}': {e}" - ) - # Depending on requirements, you might want to continue or handle differently - continue - - # Compute the final score using the method from EvaluationResult + + try: + result = 
critic.evaluate(expected_value, actual_value) + total_score += result["score"] + total_weight += critic.weight + evaluation_result.add( + critic.critic_field, + result, + critic.weight, + expected_value, + actual_value, + ) + except Exception as e: + # TODO: log or console + print(f"Critic evaluation failed for field '{critic.critic_field}': {e}") + evaluation_result.add( + critic.critic_field, + {"match": False, "score": 0.0}, + critic.weight, + expected_value, + actual_value, + ) + continue + + # Compute the final score evaluation_result.compute_final_score(total_weight) - # Set the pass/fail status based on the fail_threshold + # Set pass/fail and warning status evaluation_result.passed = evaluation_result.score >= self.rubric.fail_threshold - - # Set the warning status based on the warn_threshold evaluation_result.warning = ( not evaluation_result.passed and evaluation_result.score >= self.rubric.warn_threshold ) @@ -328,103 +328,52 @@ def evaluate(self, actual_tool_calls: list[tuple[str, dict[str, Any]]]) -> Evalu return evaluation_result def _create_cost_matrix( - self, actual_tool_calls: list[tuple[str, dict[str, Any]]] + self, + actual_tool_calls: list[tuple[str, dict[str, Any]]], + expected_tool_calls: list[ExpectedToolCall], ) -> np.ndarray: """ - Create a cost matrix for the Hungarian algorithm. - - This method computes the score for each possible pairing of expected and actual tool calls. - The resulting matrix is used by the Hungarian algorithm to find the optimal assignment. + Create a cost matrix for the assignment problem. Args: - actual_tool_calls: A list of tuples containing the actual tool calls and their arguments. + actual_tool_calls: A list of tuples of actual tool calls. + expected_tool_calls: A list of ExpectedToolCall instances. Returns: A numpy array representing the cost matrix. """ - num_expected = len(self.expected_tool_calls) + num_expected = len(expected_tool_calls) num_actual = len(actual_tool_calls) n = max(num_expected, num_actual) - # Initialize a score matrix with zeros - score_matrix = np.zeros((n, n)) + cost_matrix = np.zeros((n, n)) for i in range(n): for j in range(n): if i < num_expected and j < num_actual: - expected = self.expected_tool_calls[i] - expected_tool = expected.name - expected_args = expected.args - actual_tool, actual_args = actual_tool_calls[j] + expected = expected_tool_calls[i] + actual_name, actual_args = actual_tool_calls[j] score = 0.0 # Tool selection - if compare_tool_name(expected_tool, actual_tool): + if compare_tool_name(expected.name, actual_name): score += self.rubric.tool_selection_weight # Critics evaluation - if self.critics: - for critic in self.critics: - expected_value = expected_args.get(critic.critic_field) - actual_value = actual_args.get(critic.critic_field) - if expected_value is not None and actual_value is not None: - try: - result = critic.evaluate(expected_value, actual_value) - score += result.get("score", 0.0) - except Exception as e: - print( - f"Critic evaluation failed for field '{critic.critic_field}': {e}" - ) - continue - - score_matrix[i, j] = score - else: - # Assign a score of 0 for dummy assignments - score_matrix[i, j] = 0.0 - - return score_matrix - - async def run( - self, client: AsyncArcade, model: str, tool_names: list[FullyQualifiedName] - ) -> dict[str, Any]: - """ - Run the evaluation case asynchronously. - - Args: - client: The AsyncArcade client instance. - model: The model to evaluate. - tool_names: The list of tool names to use for the evaluation. 
- Returns: - A dictionary containing the evaluation result for the case. - """ - messages = [{"role": "system", "content": self.system_message}] - messages.extend(list(self.additional_messages)) - messages.append({"role": "user", "content": self.user_message}) - - response = await client.chat.completions.create( # type: ignore[call-overload] - model=model, - messages=messages, - tool_choice="auto", - tools=(str(name) for name in tool_names), - user="eval_user", - stream=False, - ) - - predicted_args = get_tool_args(response) - - evaluation = self.evaluate(predicted_args) - - result = { - "name": self.name, - "input": self.user_message, - "expected_tool_calls": [ - {"name": tc.name, "args": tc.args} for tc in self.expected_tool_calls - ], - "predicted_tool_calls": [{"name": tool, "args": args} for tool, args in predicted_args], - "evaluation": evaluation, - } - - return result + for critic in self.critics: # type: ignore[union-attr] + expected_value = expected.args.get(critic.critic_field) + actual_value = actual_args.get(critic.critic_field) + if expected_value is not None and actual_value is not None: + try: + result = critic.evaluate(expected_value, actual_value) + score += result.get("score", 0.0) + except Exception as e: + print( + f"Critic evaluation failed for field '{critic.critic_field}': {e}" + ) + cost_matrix[i, j] = score + + return cost_matrix @dataclass @@ -467,19 +416,19 @@ def add_case( Args: name: The name of the evaluation case. user_message: The user's input message. - system_message: The system message to be sent to the AI model. - expected_tool_calls: A list of expected tool calls. + expected_tool_calls: A list of expected tool calls as tuples of (function, args). critics: List of critics to evaluate the tool arguments. + system_message: The system message to be used. rubric: The evaluation rubric for this case. additional_messages: Optional list of additional messages for context. """ - expected = [ - ExpectedToolCall( - name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()), - args=args, - ) - for func, args in expected_tool_calls - ] + expected = [] + for func, args in expected_tool_calls: + # Fill in default arguments here + args_with_defaults = self._fill_args_with_defaults(func, args) + tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()) + expected.append(ExpectedToolCall(name=tool_name, args=args_with_defaults)) + case = EvalCase( name=name, system_message=system_message or self.system_message, @@ -491,6 +440,30 @@ def add_case( ) self.cases.append(case) + def _fill_args_with_defaults( + self, func: Callable, provided_args: dict[str, Any] + ) -> dict[str, Any]: + """ + Fill in default arguments for a tool function. + + Args: + func: The tool function. + provided_args: The provided arguments. + + Returns: + A dictionary with default arguments filled in. 
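+
+        Example (illustrative):
+
+            >>> def greet(name: str, punctuation: str = "!") -> str: ...
+            >>> suite._fill_args_with_defaults(greet, {"name": "Ada"})
+            {'name': 'Ada', 'punctuation': '!'}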
+ """ + sig = inspect.signature(func) + args_with_defaults = {} + for param in sig.parameters.values(): + if param.name in provided_args: + args_with_defaults[param.name] = provided_args[param.name] + elif param.default is not inspect.Parameter.empty: + args_with_defaults[param.name] = param.default + else: + args_with_defaults[param.name] = None # or raise an error + return args_with_defaults + def extend_case( self, name: str, @@ -528,13 +501,12 @@ def extend_case( expected = last_case.expected_tool_calls if expected_tool_calls: - expected = [ - ExpectedToolCall( - name=str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()), - args=args, - ) - for func, args in expected_tool_calls - ] + expected = [] + for func, args in expected_tool_calls: + # Fill in default arguments here + args_with_defaults = self._fill_args_with_defaults(func, args) + tool_name = str(self.catalog.find_tool_by_func(func).get_fully_qualified_name()) + expected.append(ExpectedToolCall(name=tool_name, args=args_with_defaults)) # Create a new case, copying from the last one and updating fields new_case = EvalCase( @@ -550,9 +522,10 @@ def extend_case( async def run(self, client: AsyncArcade, model: str) -> dict[str, Any]: """ - Run the evaluation suite asynchronously. + Run the evaluation suite. Args: + client: The AsyncArcade client instance. model: The model to evaluate. Returns: @@ -565,7 +538,48 @@ async def run(self, client: AsyncArcade, model: str) -> dict[str, Any]: async def sem_task(case: EvalCase) -> dict[str, Any]: async with semaphore: - return await case.run(client, model, tool_names) + # Prepare messages + messages = [{"role": "system", "content": case.system_message}] + messages.extend(case.additional_messages) + messages.append({"role": "user", "content": case.user_message}) + + # Get the model response + response = await client.chat.completions.create( # type: ignore[call-overload] + model=model, + messages=messages, + tool_choice="auto", + tools=(str(name) for name in tool_names), + user="eval_user", + stream=False, + ) + + # Extract and fill default arguments for actual tool calls + predicted_args = get_tool_args(response) + filled_actual_tool_calls = [] + for tool_name, args in predicted_args: + tool = self.catalog.get_tool_by_name(tool_name) + if tool is None: + raise ValueError(f"Tool '{tool_name}' not found in catalog.") + func = tool.tool + args_with_defaults = self._fill_args_with_defaults(func, args) + filled_actual_tool_calls.append((tool_name, args_with_defaults)) + + # Evaluate the case + evaluation = case.evaluate(filled_actual_tool_calls) + + # Prepare the result + result = { + "name": case.name, + "input": case.user_message, + "expected_tool_calls": [ + {"name": tc.name, "args": tc.args} for tc in case.expected_tool_calls + ], + "predicted_tool_calls": [ + {"name": name, "args": args} for name, args in filled_actual_tool_calls + ], + "evaluation": evaluation, + } + return result tasks = [sem_task(case) for case in self.cases] case_results = await asyncio.gather(*tasks) @@ -589,7 +603,7 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]: if message.tool_calls: for tool_call in message.tool_calls: tool_args_list.append(( - tool_call.function.name, + normalize_name(tool_call.function.name), json.loads(tool_call.function.arguments), )) return tool_args_list @@ -597,17 +611,31 @@ def get_tool_args(chat_completion: Any) -> list[tuple[str, dict[str, Any]]]: def compare_tool_name(expected: str, actual: str) -> bool: """ - Compare the tool name without 
penalizing for mismatch in separators - between module names and tool names ex. '-' vs '_' vs '.' vs ' ' + Compare the tool names by replacing all separators with the TOOL_NAME_SEPARATOR + and comparing the normalized names. + + Converts names like 'Google_ListEmails' to 'Google.ListEmails' if + TOOL_NAME_SEPARATOR is '.'. + + Args: + expected: The expected tool name. + actual: The actual tool name. + + Returns: + True if the normalized tool names match, False otherwise. """ - # TODO optimize this - # Remove all separators from both names separators = "-_." - expected_clean = "".join(char for char in expected if char not in separators) - actual_clean = "".join(char for char in actual if char not in separators) + expected_normalized = normalize_name(expected, separators) + actual_normalized = normalize_name(actual, separators) + + return expected_normalized.lower() == actual_normalized.lower() + - # Compare the cleaned names - return expected_clean.lower() == actual_clean.lower() +def normalize_name(name: str, separators: str = "-_.") -> str: + for sep in separators: + if sep != TOOL_NAME_SEPARATOR: + name = name.replace(sep, TOOL_NAME_SEPARATOR) + return name def tool_eval() -> Callable[[Callable], Callable]: diff --git a/arcade/codecov.yaml b/arcade/codecov.yaml index 058cfb76..13bf40a2 100644 --- a/arcade/codecov.yaml +++ b/arcade/codecov.yaml @@ -7,3 +7,5 @@ coverage: default: target: 90% threshold: 0.5% + exclude: + - arcade/cli/** diff --git a/arcade/pyproject.toml b/arcade/pyproject.toml index 2ffb19b0..569161ab 100644 --- a/arcade/pyproject.toml +++ b/arcade/pyproject.toml @@ -30,10 +30,12 @@ uvicorn = {version = "^0.30.0", optional = true} scipy = {version = "^1.14.0", optional = true} numpy = {version = "^2.0.0", optional = true} scikit-learn = {version = "^1.5.0", optional = true} +pytz = {version = "^2024.1", optional = true} +python-dateutil = {version = "^2.8.2", optional = true} [tool.poetry.extras] fastapi = ["fastapi", "uvicorn", "opentelemetry-instrumentation-fastapi", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-common"] -evals = ["scipy", "numpy", "scikit-learn"] +evals = ["scipy", "numpy", "scikit-learn", "pytz", "python-dateutil"] [tool.poetry.group.dev.dependencies] pytest = "^8.1.2" @@ -43,6 +45,8 @@ pre-commit = "^3.4.0" tox = "^4.11.1" pytest-asyncio = "^0.23.7" types-toml = "^0.10.8" +types-pytz = "^2024.1" +types-python-dateutil = "^2.8.2" poetry-plugin-export = "^1.7.0" [tool.poetry.scripts] @@ -63,9 +67,11 @@ ignore_missing_imports = "True" [tool.pytest.ini_options] testpaths = ["tests"] -[tool.coverage.report] -skip_empty = true [tool.coverage.run] branch = true source = ["arcade"] +omit = ["arcade/cli/*"] + +[tool.coverage.report] +skip_empty = true diff --git a/arcade/tests/client/test_client.py b/arcade/tests/client/test_client.py index 0b607665..38704fb4 100644 --- a/arcade/tests/client/test_client.py +++ b/arcade/tests/client/test_client.py @@ -373,3 +373,24 @@ async def mock_execute_request(*args, **kwargs): monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) with pytest.raises(EngineNotHealthyError): await test_async_client.health.check() + + +def test_arcade_tool_list_tools(test_sync_client, mock_response, monkeypatch): + """Test Arcade.tools.list_tools method.""" + data = [TOOL_DEFINITION_DATA] + monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: data) + tool_definitions = test_sync_client.tools.list_tools(toolkit="TestToolkit") + assert tool_definitions == 
[ToolDefinition(**TOOL_DEFINITION_DATA)] + + +@pytest.mark.asyncio +async def test_async_arcade_tool_list_tools(test_async_client, mock_async_response, monkeypatch): + """Test AsyncArcade.tools.list_tools method.""" + data = [TOOL_DEFINITION_DATA] + + async def mock_execute_request(*args, **kwargs): + return data + + monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) + tool_definitions = await test_async_client.tools.list_tools(toolkit="TestToolkit") + assert tool_definitions == [ToolDefinition(**TOOL_DEFINITION_DATA)] diff --git a/arcade/tests/core/test_catalog.py b/arcade/tests/core/test_catalog.py index 4a894a97..6254ce5f 100644 --- a/arcade/tests/core/test_catalog.py +++ b/arcade/tests/core/test_catalog.py @@ -68,7 +68,6 @@ def test_get_tool(toolkit_version: str | None, expected_tool): name="SampleTool", toolkit_name="SampleToolkit", toolkit_version=toolkit_version ) tool = catalog.get_tool(fq_name) - assert tool.tool == expected_tool @@ -102,3 +101,38 @@ class InvalidTool: assert "Type error encountered while adding tool invalid_tool from mock_module" in str( exc_info.value ) + + +def test_get_tool_by_name(): + catalog = ToolCatalog() + catalog.add_tool(sample_tool, "sample_toolkit") + + tool = catalog.get_tool_by_name("SampleToolkit.SampleTool") + assert tool.tool == sample_tool + assert tool.name == "SampleTool" + assert tool.meta.toolkit == "sample_toolkit" + assert tool.version is None + + with pytest.raises(ValueError): + catalog.get_tool_by_name("nonexistent_toolkit.SampleTool") + + +def test_get_tool_by_name_with_version(): + catalog = ToolCatalog() + catalog.add_tool(sample_tool, "sample_toolkit") + + tool = catalog.get_tool_by_name("SampleToolkit.SampleTool") + assert tool.tool == sample_tool + assert tool.name == "SampleTool" + assert tool.meta.toolkit == "sample_toolkit" + + with pytest.raises(ValueError): + catalog.get_tool_by_name("SampleToolkit.SampleTool", version="2.0.0") + + +def test_get_tool_by_name_with_invalid_version(): + catalog = ToolCatalog() + catalog.add_tool(sample_tool, "SampleToolkit") + + with pytest.raises(ValueError): + catalog.get_tool_by_name("SampleToolkit.SampleTool", version="2.0.0") diff --git a/arcade/tests/sdk/test_eval.py b/arcade/tests/sdk/test_eval.py index f6281aaa..bc9c8113 100644 --- a/arcade/tests/sdk/test_eval.py +++ b/arcade/tests/sdk/test_eval.py @@ -1,8 +1,13 @@ +from datetime import timedelta + import pytest +import pytz +from dateutil import parser from arcade.sdk.error import WeightError from arcade.sdk.eval import ( BinaryCritic, + DatetimeCritic, EvalRubric, ExpectedToolCall, NumericCritic, @@ -255,19 +260,13 @@ def test_eval_case_multiple_critics(): # Test EvalCase with missing expected and actual values in args -@pytest.mark.parametrize( - "expected_args, actual_args, expected_score", - [ - ({"param": "value"}, {}, 1.0), # Missing actual value - ({}, {"param": "value"}, 1.0), # Missing expected value - ({"param": "value"}, {"param": "value"}, 2.0), # Both values present - ], -) -def test_eval_case_missing_values(expected_args, actual_args, expected_score): +def test_eval_case_with_none_values(): """ - Test that when either expected or actual values are missing for a critic, - the critic evaluation is skipped, and the total score is computed accordingly. + Test that when expected or actual values are None, the critic evaluates them appropriately. 
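+
+    With both values None, `BinaryCritic.cast_actual` short-circuits and the
+    comparison `None == None` yields a full-weight match.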
""" + expected_args = {"param": None} + actual_args = {"param": None} + expected_tool_calls = [ExpectedToolCall(name="ToolA", args=expected_args)] actual_tool_calls = [("ToolA", actual_args)] @@ -284,15 +283,8 @@ def test_eval_case_missing_values(expected_args, actual_args, expected_score): result = case.evaluate(actual_tool_calls) - # If critic is skipped, only tool selection score is counted - # Otherwise, tool selection + critic score - total_weight = 1.0 # At least tool selection weight - if "param" in expected_args and "param" in actual_args: - total_weight += 1.0 # Critic weight - - expected_total_score = expected_score / total_weight - - assert result.score == expected_total_score + # Both values are None, so the critic should return a match + assert result.score == 2.0 / 2.0 # Full score (tool selection + critic score) # Test that WeightError is raised for invalid critic weights @@ -340,3 +332,136 @@ def test_similarity_critic_unsupported_metric(): """ with pytest.raises(ValueError): SimilarityCritic(critic_field="text", weight=1.0, metric="unsupported_metric") + + +# Test DatetimeCritic + + +# Parameterized tests for DatetimeCritic with various datetime formats and default timezones +@pytest.mark.parametrize( + "critic_params, expected, actual, expected_match, expected_score", + [ + # Test with time component and timezone + ( + {"critic_field": "start_datetime", "weight": 1.0}, + "2024-09-26T12:00:00-07:00", + "2024-09-26T12:00:00-07:00", + True, + 1.0, + ), + # Test without time component (dates only) + ( + {"critic_field": "start_datetime", "weight": 1.0}, + "2024-09-26", + "2024-09-26", + True, + 1.0, + ), + # Test with and without timezone (assumes UTC) + ( + {"critic_field": "start_datetime", "weight": 1.0}, + "2024-09-26T12:00:00Z", + "2024-09-26T12:00:00", + True, + 1.0, + ), + # Test naive datetimes + ( + {"critic_field": "start_datetime", "weight": 1.0}, + "2024-09-26T12:00:00", + "2024-09-26T12:00:00", + True, + 1.0, + ), + ], +) +def test_datetime_critic_basic(critic_params, expected, actual, expected_match, expected_score): + """ + Test DatetimeCritic with various datetime formats and default timezones. + """ + critic = DatetimeCritic(**critic_params) + result = critic.evaluate(expected, actual) + assert result["match"] == expected_match + assert result["score"] == expected_score + + +# Parameterized tests for DatetimeCritic's handling of tolerances and max differences +@pytest.mark.parametrize( + "critic_params, expected, actual, expected_match, expected_score_func", + [ + # Test time difference within tolerance + ( + {"critic_field": "start_datetime", "weight": 1.0, "tolerance": timedelta(seconds=60)}, + "2024-09-26T12:00:00", + "2024-09-26T12:00:30", + True, + lambda critic: critic.weight, + ), + # Test time difference outside tolerance but within max_difference + ( + { + "critic_field": "start_datetime", + "weight": 1.0, + "tolerance": timedelta(seconds=60), + "max_difference": timedelta(minutes=5), + }, + "2024-09-26T12:00:00", + "2024-09-26T12:04:00", + False, + lambda critic: critic.weight * (1 - (240 / 300)), + ), + # Test time difference exceeds max_difference + ( + { + "critic_field": "start_datetime", + "weight": 1.0, + "max_difference": timedelta(minutes=5), + }, + "2024-09-26T12:00:00", + "2024-09-26T12:10:00", + False, + lambda critic: 0.0, + ), + ], +) +def test_datetime_critic_tolerances( + critic_params, expected, actual, expected_match, expected_score_func +): + """ + Test DatetimeCritic's handling of tolerances and max differences. 
+ """ + critic = DatetimeCritic(**critic_params) + result = critic.evaluate(expected, actual) + assert result["match"] == expected_match + expected_score = expected_score_func(critic) + assert pytest.approx(result["score"], abs=1e-6) == expected_score + + +def test_datetime_critic_naive_and_timezone_aware(): + """ + Test DatetimeCritic when comparing naive and timezone-aware datetimes. + """ + critic = DatetimeCritic(critic_field="start_datetime", weight=1.0) + expected = "2024-09-26T12:00:00Z" + actual = "2024-09-26T07:00:00" + result = critic.evaluate(expected, actual) + assert result["match"] is False + + # Compute expected score based on time difference + expected_dt = parser.parse(expected) + actual_dt = parser.parse(actual) + if actual_dt.tzinfo is None: + actual_dt = pytz.utc.localize(actual_dt) + if expected_dt.tzinfo is None: + expected_dt = pytz.utc.localize(expected_dt) + + time_diff_seconds = abs((expected_dt - actual_dt).total_seconds()) + if time_diff_seconds <= critic.tolerance.total_seconds(): + expected_score = critic.weight + elif time_diff_seconds >= critic.max_difference.total_seconds(): + expected_score = 0.0 + else: + ratio = 1 - (time_diff_seconds / critic.max_difference.total_seconds()) + expected_score = critic.weight * ratio + + assert pytest.approx(result["score"], abs=1e-6) == expected_score diff --git a/toolkits/github/arcade_github/tests/test_repositories.py b/toolkits/github/arcade_github/tests/test_repositories.py index b484a389..154023d7 100644 --- a/toolkits/github/arcade_github/tests/test_repositories.py +++ b/toolkits/github/arcade_github/tests/test_repositories.py @@ -48,7 +48,7 @@ async def test_error_responses( if status_code == 422: await list_org_repositories(mock_context, "org", repo_type=RepoType.ALL) elif status_code == 301: - await count_stargazers("owner", "repo") + await count_stargazers(mock_context, "owner", "repo") elif status_code == 404: await list_org_repositories(mock_context, "non_existent_org") elif status_code == 503: @@ -66,8 +66,8 @@ async def test_list_repository_activities_invalid_cursor(mock_context, mock_clie @pytest.mark.asyncio -async def test_count_stargazers_success(mock_client): +async def test_count_stargazers_success(mock_context, mock_client): mock_client.get.return_value = Response(200, json={"stargazers_count": 42}) - result = await count_stargazers("owner", "repo") - assert result == "The repository owner/repo has 42 stargazers." 
+ result = await count_stargazers(mock_context, "owner", "repo") + assert result == 42 diff --git a/toolkits/github/evals/eval_github_pull_requests.py b/toolkits/github/evals/eval_github_pull_requests.py index 3ef2ef3a..25fe6579 100644 --- a/toolkits/github/evals/eval_github_pull_requests.py +++ b/toolkits/github/evals/eval_github_pull_requests.py @@ -1,8 +1,12 @@ import arcade_github -from arcade_github.tools.models import DiffSide, ReviewCommentSubjectType # Add these imports +from arcade_github.tools.models import ( + DiffSide, + ReviewCommentSubjectType, + SortDirection, +) from arcade_github.tools.pull_requests import ( create_reply_for_review_comment, - create_review_comment, # Add this import + create_review_comment, get_pull_request, list_pull_request_commits, list_pull_requests, @@ -169,7 +173,7 @@ def github_pull_requests_eval_suite() -> EvalSuite: "repo": "test", "pull_number": 72, "sort": "updated", - "direction": "asc", + "direction": SortDirection.ASC, }, ) ], diff --git a/toolkits/github/evals/eval_github_repositories.py b/toolkits/github/evals/eval_github_repositories.py index 7bcc0ba5..193fba73 100644 --- a/toolkits/github/evals/eval_github_repositories.py +++ b/toolkits/github/evals/eval_github_repositories.py @@ -1,4 +1,5 @@ import arcade_github +from arcade_github.tools.models import SortDirection from arcade_github.tools.repositories import ( count_stargazers, get_repository, @@ -66,7 +67,7 @@ def github_repositories_eval_suite() -> EvalSuite: "org": "ArcadeAI", "repo_type": "all", "sort": "created", - "sort_direction": "desc", + "sort_direction": SortDirection.DESC, }, ) ], @@ -108,7 +109,7 @@ def github_repositories_eval_suite() -> EvalSuite: { "owner": "ArcadeAI", "repo": "test", - "direction": "desc", + "direction": SortDirection.DESC, "per_page": 30, "actor": "TestUser", "time_period": "month", @@ -138,7 +139,7 @@ def github_repositories_eval_suite() -> EvalSuite: "owner": "ArcadeAI", "repo": "test", "sort": "created", - "direction": "desc", + "direction": SortDirection.DESC, "per_page": 30, "page": 1, "include_extra_data": False, diff --git a/toolkits/google/arcade_google/tools/calendar.py b/toolkits/google/arcade_google/tools/calendar.py index 8d708084..b0f37994 100644 --- a/toolkits/google/arcade_google/tools/calendar.py +++ b/toolkits/google/arcade_google/tools/calendar.py @@ -1,17 +1,16 @@ from datetime import datetime, timedelta from typing import Annotated -from zoneinfo import ZoneInfo from google.oauth2.credentials import Credentials from googleapiclient.discovery import build from googleapiclient.errors import HttpError -from arcade.core.errors import RetryableToolError, ToolExecutionError +from arcade.core.errors import RetryableToolError from arcade.core.schema import ToolContext from arcade.sdk import tool from arcade.sdk.auth import Google -from arcade_google.tools.models import Day, EventVisibility, SendUpdatesOptions, TimeSlot -from arcade_google.tools.utils import _update_datetime +from arcade_google.tools.models import EventVisibility, SendUpdatesOptions +from arcade_google.tools.utils import parse_datetime @tool( @@ -25,97 +24,87 @@ async def create_event( context: ToolContext, summary: Annotated[str, "The title of the event"], - start_date: Annotated[Day, "The day that the event starts"], - start_time: Annotated[TimeSlot, "The time of the day that the event starts"], - end_date: Annotated[Day, "The day that the event ends"], - end_time: Annotated[TimeSlot, "The time of the day that the event ends"], + start_datetime: Annotated[ + str, + "The 
datetime when the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.", + ], + end_datetime: Annotated[ + str, + "The datetime when the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.", + ], calendar_id: Annotated[ - str, "The ID of the calendar to create the event in, usually 'primary'" + str, "The ID of the calendar to create the event in, usually 'primary'." ] = "primary", description: Annotated[str | None, "The description of the event"] = None, location: Annotated[str | None, "The location of the event"] = None, visibility: Annotated[EventVisibility, "The visibility of the event"] = EventVisibility.DEFAULT, attendee_emails: Annotated[ list[str] | None, - "The list of attendee emails. Must be valid email addresses e.g., username@domain.com", + "The list of attendee emails. Must be valid email addresses e.g., username@domain.com.", ] = None, ) -> Annotated[dict, "A dictionary containing the created event details"]: """Create a new event/meeting/sync/meetup in the specified calendar.""" service = build("calendar", "v3", credentials=Credentials(context.authorization.token)) - try: - # Get the calendar's time zone - calendar = service.calendars().get(calendarId=calendar_id).execute() - time_zone = calendar["timeZone"] - - # Convert enum values to datetime objects - start_datetime = datetime.combine(start_date.to_date(time_zone), start_time.to_time()) - end_datetime = datetime.combine(end_date.to_date(time_zone), end_time.to_time()) - - event = { - "summary": summary, - "description": description, - "location": location, - "start": {"dateTime": start_datetime.isoformat(), "timeZone": time_zone}, - "end": {"dateTime": end_datetime.isoformat(), "timeZone": time_zone}, - "visibility": visibility.value, - } + # Get the calendar's time zone + calendar = service.calendars().get(calendarId=calendar_id).execute() + time_zone = calendar["timeZone"] - if attendee_emails: - event["attendees"] = [{"email": email} for email in attendee_emails] + # Parse datetime strings + start_dt = parse_datetime(start_datetime, time_zone) + end_dt = parse_datetime(end_datetime, time_zone) + + event = { + "summary": summary, + "description": description, + "location": location, + "start": {"dateTime": start_dt.isoformat(), "timeZone": time_zone}, + "end": {"dateTime": end_dt.isoformat(), "timeZone": time_zone}, + "visibility": visibility.value, + } - created_event = service.events().insert(calendarId=calendar_id, body=event).execute() + if attendee_emails: + event["attendees"] = [{"email": email} for email in attendee_emails] - except HttpError as e: - raise ToolExecutionError( - f"HttpError during execution of '{create_event.__name__}' tool.", str(e) - ) - except Exception as e: - raise ToolExecutionError( - f"Unexpected Error encountered during execution of '{create_event.__name__}' tool.", - str(e), - ) - else: - return {"event": created_event} + created_event = service.events().insert(calendarId=calendar_id, body=event).execute() + return {"event": created_event} @tool( requires_auth=Google( - scopes=["https://www.googleapis.com/auth/calendar.events.readonly"], + scopes=[ + "https://www.googleapis.com/auth/calendar.readonly", + "https://www.googleapis.com/auth/calendar.events", + ], ) ) async def list_events( context: ToolContext, - min_day: Annotated[ - Day, "Filter by events that end on or after this day. Combined with min_time_slot" - ], - min_time_slot: Annotated[ - TimeSlot, "Filter by events that end after this time. 
Combined with min_day" + min_end_datetime: Annotated[ + str, + "Filter by events that end on or after this datetime in ISO 8601 format, e.g., '2024-09-15T09:00:00'.", ], - max_day: Annotated[ - Day, "Filter by events that start on or before this day. Combined with max_time_slot" - ], - max_time_slot: Annotated[ - TimeSlot, "Filter by events that start before this time. Combined with max_day" + max_start_datetime: Annotated[ + str, + "Filter by events that start before this datetime in ISO 8601 format, e.g., '2024-09-16T17:00:00'.", ], calendar_id: Annotated[str, "The ID of the calendar to list events from"] = "primary", max_results: Annotated[int, "The maximum number of events to return"] = 10, ) -> Annotated[dict, "A dictionary containing the list of events"]: """ - List events from the specified calendar within the given date range. + List events from the specified calendar within the given datetime range. - min_day and min_time_slot are combined to form the lower bound (exclusive) for an event's end time to filter by - max_day and max_time_slot are combined to form the upper bound (exclusive) for an event's start time to filter by + min_end_datetime serves as the lower bound (exclusive) for an event's end time. + max_start_datetime serves as the upper bound (exclusive) for an event's start time. For example: - If min_day is set to Day.TODAY and min_time_slot is set to TimeSlot._09:00, - and max_day is set to Day.TOMORROW and max_time_slot is set to TimeSlot._17:00, + If min_end_datetime is set to 2024-09-15T09:00:00 and max_start_datetime is set to 2024-09-16T17:00:00, the function will return events that: - 1. End after 09:00 today (exclusive) - 2. Start before 17:00 tomorrow (exclusive) - This means an event starting at 08:00 today and ending at 10:00 today would be included, - but an event starting at 17:00 tomorrow would not be included. + 1. End after 09:00 on September 15, 2024 (exclusive) + 2. Start before 17:00 on September 16, 2024 (exclusive) + This means an event starting at 08:00 on September 15 and ending at 10:00 on September 15 would be included, + but an event starting at 17:00 on September 16 would not be included. 
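+
+    If min_end_datetime is later than max_start_datetime, the two bounds are
+    swapped rather than treated as an error, so these two calls are equivalent:
+
+        await list_events(context, "2024-09-16T17:00:00", "2024-09-15T09:00:00")
+        await list_events(context, "2024-09-15T09:00:00", "2024-09-16T17:00:00")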
""" service = build("calendar", "v3", credentials=Credentials(context.authorization.token)) @@ -123,23 +112,19 @@ async def list_events( calendar = service.calendars().get(calendarId=calendar_id).execute() time_zone = calendar["timeZone"] - # Convert enum values to datetime with timezone offset - start_datetime = datetime.combine( - min_day.to_date(time_zone), min_time_slot.to_time() - ).astimezone(ZoneInfo(time_zone)) - end_datetime = datetime.combine(max_day.to_date(time_zone), max_time_slot.to_time()).astimezone( - ZoneInfo(time_zone) - ) + # Parse datetime strings + min_end_dt = parse_datetime(min_end_datetime, time_zone) + max_start_dt = parse_datetime(max_start_datetime, time_zone) - if start_datetime > end_datetime: - start_datetime, end_datetime = end_datetime, start_datetime + if min_end_dt > max_start_dt: + min_end_dt, max_start_dt = max_start_dt, min_end_dt events_result = ( service.events() .list( calendarId=calendar_id, - timeMin=start_datetime.isoformat(), - timeMax=end_datetime.isoformat(), + timeMin=min_end_dt.isoformat(), + timeMax=max_start_dt.isoformat(), maxResults=max_results, singleEvents=True, orderBy="startTime", @@ -179,21 +164,16 @@ async def list_events( async def update_event( context: ToolContext, event_id: Annotated[str, "The ID of the event to update"], - updated_start_day: Annotated[ - Day | None, - "The updated day that the event starts. Combined with updated_start_time to form the new start time", - ] = None, - updated_start_time: Annotated[ - TimeSlot | None, - "The updated time that the event starts. Combined with updated_start_day to form the new start time", + updated_start_datetime: Annotated[ + str | None, + "The updated datetime that the event starts in ISO 8601 format, e.g., '2024-12-31T15:30:00'.", ] = None, - updated_end_day: Annotated[ - Day | None, - "The updated day that the event ends. Combined with updated_end_time to form the new end time", + updated_end_datetime: Annotated[ + str | None, + "The updated datetime that the event ends in ISO 8601 format, e.g., '2024-12-31T17:30:00'.", ] = None, - updated_end_time: Annotated[TimeSlot | None, "The updated time that the event ends"] = None, updated_calendar_id: Annotated[ - str | None, "The updated ID of the calendar containing the event" + str | None, "The updated ID of the calendar containing the event." ] = None, updated_summary: Annotated[str | None, "The updated title of the event"] = None, updated_description: Annotated[str | None, "The updated description of the event"] = None, @@ -201,25 +181,24 @@ async def update_event( updated_visibility: Annotated[EventVisibility | None, "The visibility of the event"] = None, attendee_emails_to_add: Annotated[ list[str] | None, - "The list of updated attendee emails to add. Must be valid email addresses e.g., username@domain.com", + "The list of attendee emails to add. Must be valid email addresses e.g., username@domain.com.", ] = None, attendee_emails_to_remove: Annotated[ list[str] | None, - "The list of attendee emails to remove. Must be valid email addresses e.g., username@domain.com", + "The list of attendee emails to remove. Must be valid email addresses e.g., username@domain.com.", ] = None, send_updates: Annotated[ - SendUpdatesOptions, "Guests who should receive notifications about the event update" + SendUpdatesOptions, "Should attendees be notified of the update? 
(none, all, external_only)" ] = SendUpdatesOptions.ALL, ) -> Annotated[ str, - "A string containing the updated event details, including the event ID, update timestamp, and a link to view the updated event", + "A string containing the updated event details, including the event ID, update timestamp, and a link to view the updated event.", ]: """ Update an existing event in the specified calendar with the provided details. Only the provided fields will be updated; others will remain unchanged. - `updated_start_day` and `updated_start_time` must be provided together. - `updated_end_day` and `updated_end_time` must be provided together. + `updated_start_datetime` and `updated_end_datetime` are independent and can be provided separately. """ service = build("calendar", "v3", credentials=Credentials(context.authorization.token)) @@ -228,13 +207,13 @@ async def update_event( try: event = service.events().get(calendarId="primary", eventId=event_id).execute() - except HttpError: # TODO: This is a first pass. We should do better. + except HttpError: valid_events_with_id = ( service.events() .list( calendarId="primary", timeMin=(datetime.now() - timedelta(days=2)).isoformat(), - timeMax=(datetime.now() - timedelta(days=2)).isoformat(), + timeMax=(datetime.now() + timedelta(days=365)).isoformat(), maxResults=50, singleEvents=True, orderBy="startTime", @@ -243,14 +222,18 @@ async def update_event( ) raise RetryableToolError( f"Event with ID {event_id} not found.", - additional_prompt_content=f"Here is list of valid events. The event_id parameter must match one of these: {valid_events_with_id}", + additional_prompt_content=f"Here is a list of valid events. The event_id parameter must match one of these: {valid_events_with_id}", retry_after_ms=1000, developer_message=f"Event with ID {event_id} not found. 
Please try again with a valid event ID.",
         )
 
     update_fields = {
-        "start": _update_datetime(updated_start_day, updated_start_time, time_zone),
-        "end": _update_datetime(updated_end_day, updated_end_time, time_zone),
+        "start": {
+            "dateTime": parse_datetime(updated_start_datetime, time_zone).isoformat(),
+            "timeZone": time_zone,
+        }
+        if updated_start_datetime
+        else None,
+        "end": {
+            "dateTime": parse_datetime(updated_end_datetime, time_zone).isoformat(),
+            "timeZone": time_zone,
+        }
+        if updated_end_datetime
+        else None,
         "calendarId": updated_calendar_id,
         "sendUpdates": send_updates.value if send_updates else None,
         "summary": updated_summary,
@@ -265,12 +248,20 @@ async def update_event(
         event["attendees"] = [
             attendee
             for attendee in event.get("attendees", [])
-            if attendee.get("email", "") not in attendee_emails_to_remove
+            if attendee.get("email", "").lower()
+            not in [email.lower() for email in attendee_emails_to_remove]
         ]
+
     if attendee_emails_to_add:
-        event["attendees"] = event.get("attendees", []) + [
-            {"email": email} for email in attendee_emails_to_add
+        existing_emails = {
+            attendee.get("email", "").lower() for attendee in event.get("attendees", [])
+        }
+        new_attendees = [
+            {"email": email}
+            for email in attendee_emails_to_add
+            if email.lower() not in existing_emails
         ]
+        event["attendees"] = event.get("attendees", []) + new_attendees
 
     updated_event = (
         service.events()
diff --git a/toolkits/google/arcade_google/tools/utils.py b/toolkits/google/arcade_google/tools/utils.py
index 44997a7c..cf7d426b 100644
--- a/toolkits/google/arcade_google/tools/utils.py
+++ b/toolkits/google/arcade_google/tools/utils.py
@@ -1,8 +1,9 @@
-import datetime
 import re
 from base64 import urlsafe_b64decode
+from datetime import datetime, timedelta
 from enum import Enum
 from typing import Any, Optional
+from zoneinfo import ZoneInfo
 
 from bs4 import BeautifulSoup
 from google.oauth2.credentials import Credentials
@@ -11,6 +12,31 @@
 from arcade_google.tools.models import Day, TimeSlot
 
 
+def parse_datetime(datetime_str: str, time_zone: str) -> datetime:
+    """
+    Parse a datetime string in ISO 8601 format and ensure it is timezone-aware.
+
+    Args:
+        datetime_str (str): The datetime string to parse. Expected format: 'YYYY-MM-DDTHH:MM:SS'.
+        time_zone (str): The timezone to apply if the datetime string is naive.
+
+    Returns:
+        datetime: A timezone-aware datetime object.
+
+    Raises:
+        ValueError: If the datetime string is not in the correct format.
+    """
+    try:
+        dt = datetime.fromisoformat(datetime_str)
+        if dt.tzinfo is None:
+            dt = dt.replace(tzinfo=ZoneInfo(time_zone))
+    except ValueError as e:
+        raise ValueError(
+            f"Invalid datetime format: '{datetime_str}'. Expected ISO 8601 format, e.g., '2024-12-31T15:30:00'."
+ ) from e + return dt + + class DateRange(Enum): TODAY = "today" YESTERDAY = "yesterday" @@ -21,24 +47,24 @@ class DateRange(Enum): THIS_YEAR = "this_year" def to_date_query(self): - today = datetime.datetime.now() + today = datetime.now() result = "after:" comparison_date = today if self == DateRange.YESTERDAY: - comparison_date = today - datetime.timedelta(days=1) + comparison_date = today - timedelta(days=1) elif self == DateRange.LAST_7_DAYS: - comparison_date = today - datetime.timedelta(days=7) + comparison_date = today - timedelta(days=7) elif self == DateRange.LAST_30_DAYS: - comparison_date = today - datetime.timedelta(days=30) + comparison_date = today - timedelta(days=30) elif self == DateRange.THIS_MONTH: comparison_date = today.replace(day=1) elif self == DateRange.LAST_MONTH: - comparison_date = (today.replace(day=1) - datetime.timedelta(days=1)).replace(day=1) + comparison_date = (today.replace(day=1) - timedelta(days=1)).replace(day=1) elif self == DateRange.THIS_YEAR: comparison_date = today.replace(month=1, day=1) elif self == DateRange.LAST_MONTH: - comparison_date = (today.replace(month=1, day=1) - datetime.timedelta(days=1)).replace( + comparison_date = (today.replace(month=1, day=1) - timedelta(days=1)).replace( month=1, day=1 ) diff --git a/toolkits/google/evals/eval_google_calendar.py b/toolkits/google/evals/eval_google_calendar.py index c7accdcc..105eef69 100644 --- a/toolkits/google/evals/eval_google_calendar.py +++ b/toolkits/google/evals/eval_google_calendar.py @@ -1,10 +1,19 @@ +from datetime import timedelta + import arcade_google -from arcade_google.tools.calendar import create_event, delete_event, list_events, update_event -from arcade_google.tools.models import Day, EventVisibility, TimeSlot +from arcade_google.tools.calendar import ( + EventVisibility, + SendUpdatesOptions, + create_event, + delete_event, + list_events, + update_event, +) from arcade.core.catalog import ToolCatalog from arcade.sdk.eval import ( BinaryCritic, + DatetimeCritic, EvalRubric, EvalSuite, tool_eval, @@ -16,11 +25,9 @@ warn_threshold=0.95, ) - catalog = ToolCatalog() catalog.add_module(arcade_google) - history_after_list_events = [ {"role": "user", "content": "do i have any events on my calendar for today?"}, { @@ -36,7 +43,7 @@ "type": "function", "function": { "name": "Google_ListEvents", - "arguments": '{"max_day":"today","max_time_slot":"23:45","min_day":"today","min_time_slot":"00:00"}', + "arguments": '{"min_end_datetime":"2024-09-26T00:00:00-07:00","max_start_datetime":"2024-09-27T00:00:00-07:00"}', }, } ], @@ -59,7 +66,9 @@ def calendar_eval_suite() -> EvalSuite: """Create an evaluation suite for Calendar tools.""" suite = EvalSuite( name="Calendar Tools Evaluation", - system_message="You are an AI assistant that can create and list events using the provided tools.", + system_message=( + "You are an AI assistant that can create, list, update, and delete events using the provided tools. Today is 2024-09-26" + ), catalog=catalog, rubric=rubric, ) @@ -67,32 +76,34 @@ def calendar_eval_suite() -> EvalSuite: # Cases for create_event suite.add_case( name="Create calendar event", - user_message="Create a meeting for 'Team Meeting' starting next thursday from 11:45pm to 12:15am. Invite johndoe@example.com", + user_message=( + "Create a meeting for 'Team Meeting' starting on September 26, 2024, from 11:45pm to 12:15am. 
Invite johndoe@example.com"
+        ),
         expected_tool_calls=[
             (
                 create_event,
                 {
                     "summary": "Team Meeting",
-                    "start_date": Day.NEXT_THURSDAY.value,
-                    "start_time": TimeSlot._2345.value,
-                    "end_date": Day.NEXT_FRIDAY.value,
-                    "end_time": TimeSlot._0015.value,
+                    "start_datetime": "2024-09-26T23:45:00",
+                    "end_datetime": "2024-09-27T00:15:00",
                     "calendar_id": "primary",
                     "attendee_emails": ["johndoe@example.com"],
-                    "description": None,
-                    "location": None,
                     "visibility": EventVisibility.DEFAULT,
+                    "description": "Team Meeting",
                 },
             )
         ],
         critics=[
-            BinaryCritic(critic_field="summary", weight=0.15),
-            BinaryCritic(critic_field="start_date", weight=0.15),
-            BinaryCritic(critic_field="start_time", weight=0.15),
-            BinaryCritic(critic_field="end_date", weight=0.15),
-            BinaryCritic(critic_field="end_time", weight=0.15),
-            BinaryCritic(critic_field="attendee_emails", weight=0.15),
+            BinaryCritic(critic_field="summary", weight=0.2),
+            DatetimeCritic(
+                critic_field="start_datetime", weight=0.2, tolerance=timedelta(seconds=10)
+            ),
+            DatetimeCritic(
+                critic_field="end_datetime", weight=0.2, tolerance=timedelta(seconds=10)
+            ),
+            BinaryCritic(critic_field="attendee_emails", weight=0.2),
             BinaryCritic(critic_field="description", weight=0.1),
+            BinaryCritic(critic_field="location", weight=0.1),
         ],
     )
 
@@ -104,49 +115,61 @@ def calendar_eval_suite() -> EvalSuite:
             (
                 list_events,
                 {
-                    "min_day": Day.TODAY.value,
-                    "min_time_slot": TimeSlot._0000.value,
-                    "max_day": Day.TOMORROW.value,
-                    "max_time_slot": TimeSlot._0000.value,
+                    "min_end_datetime": "2024-09-26T00:00:00",
+                    "max_start_datetime": "2024-09-27T00:00:00",
                     "calendar_id": "primary",
-                    "event_types": None,
                     "max_results": 10,
                 },
             )
         ],
         critics=[
-            BinaryCritic(critic_field="min_day", weight=0.1),
-            BinaryCritic(critic_field="min_time_slot", weight=0.1),
-            BinaryCritic(critic_field="max_day", weight=0.1),
-            BinaryCritic(critic_field="max_time_slot", weight=0.1),
-            BinaryCritic(critic_field="calendar_id", weight=0.1),
-            BinaryCritic(critic_field="event_types", weight=0.1),
-            BinaryCritic(critic_field="max_results", weight=0.1),
+            DatetimeCritic(
+                critic_field="min_end_datetime", weight=0.3, tolerance=timedelta(hours=1)
+            ),
+            DatetimeCritic(
+                critic_field="max_start_datetime", weight=0.3, tolerance=timedelta(hours=1)
+            ),
+            BinaryCritic(critic_field="calendar_id", weight=0.2),
+            BinaryCritic(critic_field="max_results", weight=0.2),
         ],
     )
 
     # Cases for update_event
     suite.add_case(
         name="Update a calendar event",
-        user_message="Oh no! I cant make it to the API Test since i have lunch with an old friend at that time. Change the meeting to 3pm to 4pm please.",
+        user_message=(
+            "Oh no! I can't make it to the API Test since I have lunch with an old friend at that time. "
+            "Change my meeting tomorrow at 3pm to 4pm. Let everyone know."
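+            # "tomorrow" is resolved against the date pinned in the system
+            # message ("Today is 2024-09-26"), hence the 2024-09-27 datetimes below.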
+ ), expected_tool_calls=[ ( update_event, { "event_id": "00099992228181818181", - "updated_start_day": Day.TODAY.value, - "updated_start_time": TimeSlot._1500.value, - "updated_end_day": Day.TODAY.value, - "updated_end_time": TimeSlot._1600.value, + "updated_start_datetime": "2024-09-27T16:00:00", + "updated_end_datetime": "2024-09-27T18:00:00", + "updated_calendar_id": "primary", + "updated_summary": "API Test", + "updated_description": "API Test", + "updated_location": "611 Gateway Blvd", + "updated_visibility": EventVisibility.DEFAULT, + "attendee_emails_to_add": None, + "attendee_emails_to_remove": None, + "send_updates": SendUpdatesOptions.ALL, }, ) ], critics=[ - BinaryCritic(critic_field="event_id", weight=0.2), - BinaryCritic(critic_field="updated_start_day", weight=0.1), - BinaryCritic(critic_field="updated_start_time", weight=0.1), - BinaryCritic(critic_field="updated_end_day", weight=0.1), - BinaryCritic(critic_field="updated_end_time", weight=0.1), + BinaryCritic(critic_field="event_id", weight=0.4), + DatetimeCritic( + critic_field="updated_start_datetime", weight=0.2, tolerance=timedelta(minutes=15) + ), + DatetimeCritic( + critic_field="updated_end_datetime", + weight=0.2, + tolerance=timedelta(minutes=15), + ), + BinaryCritic(critic_field="send_updates", weight=0.2), ], additional_messages=history_after_list_events, ) @@ -154,14 +177,16 @@ def calendar_eval_suite() -> EvalSuite: # Cases for delete_event suite.add_case( name="Delete a calendar event", - user_message="I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications.", + user_message=( + "I don't need to have focus time today. Please delete it from my calendar. Don't send any notifications." + ), expected_tool_calls=[ ( delete_event, { "event_id": "gr5g18lf88tfpp3vkareukkc7g", "calendar_id": "primary", - "send_updates": "none", + "send_updates": SendUpdatesOptions.NONE, }, ) ], diff --git a/toolkits/google/tests/test_calendar.py b/toolkits/google/tests/test_calendar.py index f06fbc22..7a2f7020 100644 --- a/toolkits/google/tests/test_calendar.py +++ b/toolkits/google/tests/test_calendar.py @@ -2,7 +2,7 @@ import pytest from arcade_google.tools.calendar import create_event, delete_event, list_events, update_event -from arcade_google.tools.models import Day, EventVisibility, SendUpdatesOptions, TimeSlot +from arcade_google.tools.models import EventVisibility, SendUpdatesOptions from googleapiclient.errors import HttpError from arcade.core.errors import ToolExecutionError @@ -24,7 +24,7 @@ async def test_create_event(mock_build, mock_context): # Mock the calendar's time zone mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"} - # Case: HttpError + # Case: HttpError during event creation mock_service.events().insert().execute.side_effect = HttpError( resp=MagicMock(status=400), content=b'{"error": {"message": "Invalid request"}}', @@ -34,10 +34,8 @@ async def test_create_event(mock_build, mock_context): await create_event( context=mock_context, summary="Test Event", - start_date=Day.TODAY, - start_time=TimeSlot._1615, - end_date=Day.TODAY, - end_time=TimeSlot._1715, + start_datetime="2024-12-31T15:30:00", + end_datetime="2024-12-31T17:30:00", description="Test Description", location="Test Location", visibility=EventVisibility.PRIVATE, @@ -53,34 +51,34 @@ async def test_list_events(mock_build, mock_context): # Mock the calendar's time zone mock_service.calendars().get().execute.return_value = {"timeZone": "America/Los_Angeles"} - # Case: 
min time is after max time. list_events tool should swap the times and still return the events + # Mock the events list response mock_events_list_response = { "items": [ { "creator": {"email": "example@arcade-ai.com", "self": True}, "end": {"dateTime": "2024-09-27T01:00:00-07:00", "timeZone": "America/Los_Angeles"}, "eventType": "default", - "htmlLink": "https://www.google.com/calendar/event?eid=N2pmYjZ0ZmNnMGNydG5scmhkY2JvZWc4OGIgZXJpY0BhcmNhZGUtYWku", - "id": "7jfb6tfcg0crtnlrhdcboeg88b", + "htmlLink": "https://www.google.com/calendar/event?eid=event1", + "id": "event1", "organizer": {"email": "example@arcade-ai.com", "self": True}, "start": { "dateTime": "2024-09-27T00:00:00-07:00", "timeZone": "America/Los_Angeles", }, - "summary": "teST", + "summary": "Event 1", }, { "creator": {"email": "example@arcade-ai.com", "self": True}, "end": {"dateTime": "2024-09-27T17:00:00-07:00", "timeZone": "America/Los_Angeles"}, "eventType": "default", - "htmlLink": "https://www.google.com/calendar/event?eid=MjZvYnRoc2xtMWMzbG5mdG10bzk4cDcxaGMgZXJpY0BhcmNhZGUtYWku", - "id": "26obthslm1c3lnftmto98p71hc", + "htmlLink": "https://www.google.com/calendar/event?eid=event2", + "id": "event2", "organizer": {"email": "example@arcade-ai.com", "self": True}, "start": { "dateTime": "2024-09-27T14:00:00-07:00", "timeZone": "America/Los_Angeles", }, - "summary": "New Event", + "summary": "Event 2", }, ] } @@ -89,16 +87,14 @@ async def test_list_events(mock_build, mock_context): "events": mock_events_list_response["items"], } mock_service.events().list().execute.return_value = mock_events_list_response - message = await list_events( + response = await list_events( context=mock_context, - min_day=Day.TODAY, - min_time_slot=TimeSlot._1615, - max_day=Day.TODAY, - max_time_slot=TimeSlot._1515, + min_end_datetime="2024-09-15T09:00:00", + max_start_datetime="2024-09-16T17:00:00", ) - assert message == expected_tool_response + assert response == expected_tool_response - # Case: HttpError + # Case: HttpError during events listing mock_service.events().list().execute.side_effect = HttpError( resp=MagicMock(status=400), content=b'{"error": {"message": "Invalid request"}}', @@ -107,10 +103,8 @@ async def test_list_events(mock_build, mock_context): with pytest.raises(ToolExecutionError): await list_events( context=mock_context, - min_day=Day.TODAY, - min_time_slot=TimeSlot._1615, - max_day=Day.TOMORROW, - max_time_slot=TimeSlot._1815, + min_end_datetime="2024-09-15T09:00:00", + max_start_datetime="2024-09-16T17:00:00", ) @@ -119,8 +113,10 @@ async def test_list_events(mock_build, mock_context): async def test_update_event(mock_build, mock_context): mock_service = MagicMock() mock_build.return_value = mock_service - mock_service.events().update().execute.side_effect = HttpError( - resp=MagicMock(status=400), + + # Mock retrieval of the event + mock_service.events().get().execute.side_effect = HttpError( + resp=MagicMock(status=404), content=b'{"error": {"message": "Event not found"}}', ) @@ -128,10 +124,8 @@ async def test_update_event(mock_build, mock_context): await update_event( context=mock_context, event_id="1234567890", - updated_start_day=Day.NEXT_FRIDAY, - updated_start_time=TimeSlot._0015, - updated_end_day=Day.NEXT_FRIDAY, - updated_end_time=TimeSlot._0115, + updated_start_datetime="2024-12-31T00:15:00", + updated_end_datetime="2024-12-31T01:15:00", updated_summary="Updated Event", updated_description="Updated Description", updated_location="Updated Location", @@ -148,7 +142,7 @@ async def test_delete_event(mock_build, 
mock_context): mock_service = MagicMock() mock_build.return_value = mock_service mock_service.events().delete().execute.side_effect = HttpError( - resp=MagicMock(status=400), + resp=MagicMock(status=404), content=b'{"error": {"message": "Event not found"}}', ) @@ -156,4 +150,5 @@ async def test_delete_event(mock_build, mock_context): await delete_event( context=mock_context, event_id="nonexistent_event", + send_updates=SendUpdatesOptions.ALL, ) diff --git a/toolkits/math/tests/test_arithmetic.py b/toolkits/math/tests/test_arithmetic.py index af40a9cb..67d307e9 100644 --- a/toolkits/math/tests/test_arithmetic.py +++ b/toolkits/math/tests/test_arithmetic.py @@ -9,6 +9,8 @@ sum_range, ) +from arcade.sdk.error import ToolExecutionError + def test_add(): assert add(1, 2) == 3 @@ -29,7 +31,7 @@ def test_multiply(): def test_divide(): assert divide(6, 3) == 2.0 assert divide(5, 2) == 2.5 - with pytest.raises(ZeroDivisionError): + with pytest.raises(ToolExecutionError): divide(1, 0)
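
The divide test now expects ToolExecutionError rather than ZeroDivisionError, which implies the arithmetic tools catch the builtin exception and re-raise it at the tool layer. A minimal sketch of that presumed pattern (the real arcade_math implementation may differ):

    from arcade.sdk.error import ToolExecutionError

    def divide(dividend: float, divisor: float) -> float:
        # Presumed behavior: surface division by zero as a tool-level error
        # so callers see a uniform ToolExecutionError for any tool failure.
        try:
            return dividend / divisor
        except ZeroDivisionError as e:
            raise ToolExecutionError("Division by zero is not allowed.") from e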