diff --git a/dotnet/test/ToolsTests.cs b/dotnet/test/ToolsTests.cs index 942a09a09..c6449ec8f 100644 --- a/dotnet/test/ToolsTests.cs +++ b/dotnet/test/ToolsTests.cs @@ -5,6 +5,7 @@ using GitHub.Copilot.SDK.Test.Harness; using Microsoft.Extensions.AI; using System.ComponentModel; +using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; using Xunit; @@ -42,6 +43,7 @@ public async Task Invokes_Custom_Tool() var session = await CreateSessionAsync(new SessionConfig { Tools = [AIFunctionFactory.Create(EncryptString, "encrypt_string")], + OnPermissionRequest = PermissionHandler.ApproveAll, }); await session.SendAsync(new MessageOptions @@ -66,7 +68,8 @@ public async Task Handles_Tool_Calling_Errors() var session = await CreateSessionAsync(new SessionConfig { - Tools = [getUserLocation] + Tools = [getUserLocation], + OnPermissionRequest = PermissionHandler.ApproveAll, }); await session.SendAsync(new MessageOptions { Prompt = "What is my location? If you can't find out, just say 'unknown'." }); @@ -108,6 +111,7 @@ public async Task Can_Receive_And_Return_Complex_Types() var session = await CreateSessionAsync(new SessionConfig { Tools = [AIFunctionFactory.Create(PerformDbQuery, "db_query", serializerOptions: ToolsTestsJsonContext.Default.Options)], + OnPermissionRequest = PermissionHandler.ApproveAll, }); await session.SendAsync(new MessageOptions @@ -154,6 +158,7 @@ public async Task Can_Return_Binary_Result() var session = await CreateSessionAsync(new SessionConfig { Tools = [AIFunctionFactory.Create(GetImage, "get_image")], + OnPermissionRequest = PermissionHandler.ApproveAll, }); await session.SendAsync(new MessageOptions @@ -177,4 +182,72 @@ await session.SendAsync(new MessageOptions SessionLog = "Returned an image", }); } + + [Fact] + public async Task Invokes_Custom_Tool_With_Permission_Handler() + { + var permissionRequests = new List(); + + var session = await Client.CreateSessionAsync(new SessionConfig + { + Tools = [AIFunctionFactory.Create(EncryptStringForPermission, "encrypt_string")], + OnPermissionRequest = (request, invocation) => + { + permissionRequests.Add(request); + return Task.FromResult(new PermissionRequestResult { Kind = "approved" }); + }, + }); + + await session.SendAsync(new MessageOptions + { + Prompt = "Use encrypt_string to encrypt this string: Hello" + }); + + var assistantMessage = await TestHelper.GetFinalAssistantMessageAsync(session); + Assert.NotNull(assistantMessage); + Assert.Contains("HELLO", assistantMessage!.Data.Content ?? string.Empty); + + // Should have received a custom-tool permission request with the correct tool name + var customToolRequest = permissionRequests.FirstOrDefault(r => r.Kind == "custom-tool"); + Assert.NotNull(customToolRequest); + Assert.True(customToolRequest!.ExtensionData?.ContainsKey("toolName") ?? false); + var toolName = ((JsonElement)customToolRequest.ExtensionData!["toolName"]).GetString(); + Assert.Equal("encrypt_string", toolName); + + [Description("Encrypts a string")] + static string EncryptStringForPermission([Description("String to encrypt")] string input) + => input.ToUpperInvariant(); + } + + [Fact] + public async Task Denies_Custom_Tool_When_Permission_Denied() + { + var toolHandlerCalled = false; + + var session = await Client.CreateSessionAsync(new SessionConfig + { + Tools = [AIFunctionFactory.Create(EncryptStringDenied, "encrypt_string")], + OnPermissionRequest = (request, invocation) => + { + return Task.FromResult(new PermissionRequestResult { Kind = "denied-interactively-by-user" }); + }, + }); + + await session.SendAsync(new MessageOptions + { + Prompt = "Use encrypt_string to encrypt this string: Hello" + }); + + await TestHelper.GetFinalAssistantMessageAsync(session); + + // The tool handler should NOT have been called since permission was denied + Assert.False(toolHandlerCalled); + + [Description("Encrypts a string")] + string EncryptStringDenied([Description("String to encrypt")] string input) + { + toolHandlerCalled = true; + return input.ToUpperInvariant(); + } + } } diff --git a/go/internal/e2e/tools_test.go b/go/internal/e2e/tools_test.go index b38e41a60..e5b93fa25 100644 --- a/go/internal/e2e/tools_test.go +++ b/go/internal/e2e/tools_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" copilot "github.com/github/copilot-sdk/go" @@ -262,4 +263,103 @@ func TestTools(t *testing.T) { t.Errorf("Expected session ID '%s', got '%s'", session.SessionID, receivedInvocation.SessionID) } }) + + t.Run("invokes custom tool with permission handler", func(t *testing.T) { + ctx.ConfigureForTest(t) + + type EncryptParams struct { + Input string `json:"input" jsonschema:"String to encrypt"` + } + + var permissionRequests []copilot.PermissionRequest + var mu sync.Mutex + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + Tools: []copilot.Tool{ + copilot.DefineTool("encrypt_string", "Encrypts a string", + func(params EncryptParams, inv copilot.ToolInvocation) (string, error) { + return strings.ToUpper(params.Input), nil + }), + }, + OnPermissionRequest: func(request copilot.PermissionRequest, invocation copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { + mu.Lock() + permissionRequests = append(permissionRequests, request) + mu.Unlock() + return copilot.PermissionRequestResult{Kind: "approved"}, nil + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "HELLO") { + t.Errorf("Expected answer to contain 'HELLO', got %v", answer.Data.Content) + } + + // Should have received a custom-tool permission request + mu.Lock() + customToolReqs := 0 + for _, req := range permissionRequests { + if req.Kind == "custom-tool" { + customToolReqs++ + if toolName, ok := req.Extra["toolName"].(string); !ok || toolName != "encrypt_string" { + t.Errorf("Expected toolName 'encrypt_string', got '%v'", req.Extra["toolName"]) + } + } + } + mu.Unlock() + if customToolReqs == 0 { + t.Errorf("Expected at least one custom-tool permission request, got none") + } + }) + + t.Run("denies custom tool when permission denied", func(t *testing.T) { + ctx.ConfigureForTest(t) + + type EncryptParams struct { + Input string `json:"input" jsonschema:"String to encrypt"` + } + + toolHandlerCalled := false + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + Tools: []copilot.Tool{ + copilot.DefineTool("encrypt_string", "Encrypts a string", + func(params EncryptParams, inv copilot.ToolInvocation) (string, error) { + toolHandlerCalled = true + return strings.ToUpper(params.Input), nil + }), + }, + OnPermissionRequest: func(request copilot.PermissionRequest, invocation copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { + return copilot.PermissionRequestResult{Kind: "denied-interactively-by-user"}, nil + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + _, err = testharness.GetFinalAssistantMessage(t.Context(), session) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + + if toolHandlerCalled { + t.Errorf("Tool handler should NOT have been called since permission was denied") + } + }) } diff --git a/go/types.go b/go/types.go index f3f299ed5..225cc1266 100644 --- a/go/types.go +++ b/go/types.go @@ -106,6 +106,32 @@ type PermissionRequest struct { Extra map[string]any `json:"-"` // Additional fields vary by kind } +// UnmarshalJSON implements custom JSON unmarshaling for PermissionRequest +// to capture additional fields (varying by kind) into the Extra map. +func (p *PermissionRequest) UnmarshalJSON(data []byte) error { + // Unmarshal known fields via an alias to avoid infinite recursion + type Alias PermissionRequest + var alias Alias + if err := json.Unmarshal(data, &alias); err != nil { + return err + } + *p = PermissionRequest(alias) + + // Unmarshal all fields into a generic map + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Remove known fields, keep the rest as Extra + delete(raw, "kind") + delete(raw, "toolCallId") + if len(raw) > 0 { + p.Extra = raw + } + return nil +} + // PermissionRequestResult represents the result of a permission request type PermissionRequestResult struct { Kind string `json:"kind"` diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index c016edff2..3a0ccbce7 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -211,7 +211,7 @@ export type SystemMessageConfig = SystemMessageAppendConfig | SystemMessageRepla * Permission request types from the server */ export interface PermissionRequest { - kind: "shell" | "write" | "mcp" | "read" | "url"; + kind: "shell" | "write" | "mcp" | "read" | "url" | "custom-tool"; toolCallId?: string; [key: string]: unknown; } diff --git a/nodejs/test/e2e/tools.test.ts b/nodejs/test/e2e/tools.test.ts index a6ad0c049..feab2fbfa 100644 --- a/nodejs/test/e2e/tools.test.ts +++ b/nodejs/test/e2e/tools.test.ts @@ -7,6 +7,7 @@ import { join } from "path"; import { assert, describe, expect, it } from "vitest"; import { z } from "zod"; import { defineTool, approveAll } from "../../src/index.js"; +import type { PermissionRequest } from "../../src/index.js"; import { createSdkTestContext } from "./harness/sdkTestContext"; describe("Custom tools", async () => { @@ -36,6 +37,7 @@ describe("Custom tools", async () => { handler: ({ input }) => input.toUpperCase(), }), ], + onPermissionRequest: approveAll, }); const assistantMessage = await session.sendAndWait({ @@ -55,6 +57,7 @@ describe("Custom tools", async () => { }, }), ], + onPermissionRequest: approveAll, }); const answer = await session.sendAndWait({ @@ -111,6 +114,7 @@ describe("Custom tools", async () => { }, }), ], + onPermissionRequest: approveAll, }); const assistantMessage = await session.sendAndWait({ @@ -127,4 +131,63 @@ describe("Custom tools", async () => { expect(responseContent.replace(/,/g, "")).toContain("135460"); expect(responseContent.replace(/,/g, "")).toContain("204356"); }); + + it("invokes custom tool with permission handler", async () => { + const permissionRequests: PermissionRequest[] = []; + + const session = await client.createSession({ + tools: [ + defineTool("encrypt_string", { + description: "Encrypts a string", + parameters: z.object({ + input: z.string().describe("String to encrypt"), + }), + handler: ({ input }) => input.toUpperCase(), + }), + ], + onPermissionRequest: (request) => { + permissionRequests.push(request); + return { kind: "approved" }; + }, + }); + + const assistantMessage = await session.sendAndWait({ + prompt: "Use encrypt_string to encrypt this string: Hello", + }); + expect(assistantMessage?.data.content).toContain("HELLO"); + + // Should have received a custom-tool permission request + const customToolRequests = permissionRequests.filter((req) => req.kind === "custom-tool"); + expect(customToolRequests.length).toBeGreaterThan(0); + expect(customToolRequests[0].toolName).toBe("encrypt_string"); + }); + + it("denies custom tool when permission denied", async () => { + let toolHandlerCalled = false; + + const session = await client.createSession({ + tools: [ + defineTool("encrypt_string", { + description: "Encrypts a string", + parameters: z.object({ + input: z.string().describe("String to encrypt"), + }), + handler: ({ input }) => { + toolHandlerCalled = true; + return input.toUpperCase(); + }, + }), + ], + onPermissionRequest: () => { + return { kind: "denied-interactively-by-user" }; + }, + }); + + await session.sendAndWait({ + prompt: "Use encrypt_string to encrypt this string: Hello", + }); + + // The tool handler should NOT have been called since permission was denied + expect(toolHandlerCalled).toBe(false); + }); }); diff --git a/python/copilot/types.py b/python/copilot/types.py index e89399777..142aee474 100644 --- a/python/copilot/types.py +++ b/python/copilot/types.py @@ -169,7 +169,7 @@ class SystemMessageReplaceConfig(TypedDict): class PermissionRequest(TypedDict, total=False): """Permission request from the server""" - kind: Literal["shell", "write", "mcp", "read", "url"] + kind: Literal["shell", "write", "mcp", "read", "url", "custom-tool"] toolCallId: str # Additional fields vary by kind diff --git a/python/e2e/test_tools.py b/python/e2e/test_tools.py index 485998e00..e4a9f5f06 100644 --- a/python/e2e/test_tools.py +++ b/python/e2e/test_tools.py @@ -132,3 +132,61 @@ def db_query(params: DbQueryParams, invocation: ToolInvocation) -> list[City]: assert "San Lorenzo" in response_content assert "135460" in response_content.replace(",", "") assert "204356" in response_content.replace(",", "") + + async def test_invokes_custom_tool_with_permission_handler(self, ctx: E2ETestContext): + class EncryptParams(BaseModel): + input: str = Field(description="String to encrypt") + + @define_tool("encrypt_string", description="Encrypts a string") + def encrypt_string(params: EncryptParams, invocation: ToolInvocation) -> str: + return params.input.upper() + + permission_requests = [] + + def on_permission_request(request, invocation): + permission_requests.append(request) + return {"kind": "approved"} + + session = await ctx.client.create_session( + { + "tools": [encrypt_string], + "on_permission_request": on_permission_request, + } + ) + + await session.send({"prompt": "Use encrypt_string to encrypt this string: Hello"}) + assistant_message = await get_final_assistant_message(session) + assert "HELLO" in assistant_message.data.content + + # Should have received a custom-tool permission request + custom_tool_requests = [r for r in permission_requests if r.get("kind") == "custom-tool"] + assert len(custom_tool_requests) > 0 + assert custom_tool_requests[0].get("toolName") == "encrypt_string" + + async def test_denies_custom_tool_when_permission_denied(self, ctx: E2ETestContext): + tool_handler_called = False + + class EncryptParams(BaseModel): + input: str = Field(description="String to encrypt") + + @define_tool("encrypt_string", description="Encrypts a string") + def encrypt_string(params: EncryptParams, invocation: ToolInvocation) -> str: + nonlocal tool_handler_called + tool_handler_called = True + return params.input.upper() + + def on_permission_request(request, invocation): + return {"kind": "denied-interactively-by-user"} + + session = await ctx.client.create_session( + { + "tools": [encrypt_string], + "on_permission_request": on_permission_request, + } + ) + + await session.send({"prompt": "Use encrypt_string to encrypt this string: Hello"}) + await get_final_assistant_message(session) + + # The tool handler should NOT have been called since permission was denied + assert not tool_handler_called diff --git a/test/snapshots/tools/denies_custom_tool_when_permission_denied.yaml b/test/snapshots/tools/denies_custom_tool_when_permission_denied.yaml new file mode 100644 index 000000000..47f9286e0 --- /dev/null +++ b/test/snapshots/tools/denies_custom_tool_when_permission_denied.yaml @@ -0,0 +1,15 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: "Use encrypt_string to encrypt this string: Hello" + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: encrypt_string + arguments: '{"input":"Hello"}' diff --git a/test/snapshots/tools/invokes_custom_tool_with_permission_handler.yaml b/test/snapshots/tools/invokes_custom_tool_with_permission_handler.yaml new file mode 100644 index 000000000..5b046d4c3 --- /dev/null +++ b/test/snapshots/tools/invokes_custom_tool_with_permission_handler.yaml @@ -0,0 +1,20 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: "Use encrypt_string to encrypt this string: Hello" + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: encrypt_string + arguments: '{"input":"Hello"}' + - role: tool + tool_call_id: toolcall_0 + content: HELLO + - role: assistant + content: "The encrypted result is: **HELLO**"