Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion dotnet/test/ToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using GitHub.Copilot.SDK.Test.Harness;
using Microsoft.Extensions.AI;
using System.ComponentModel;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using Xunit;
Expand Down Expand Up @@ -42,6 +43,7 @@ public async Task Invokes_Custom_Tool()
var session = await CreateSessionAsync(new SessionConfig
{
Tools = [AIFunctionFactory.Create(EncryptString, "encrypt_string")],
OnPermissionRequest = PermissionHandler.ApproveAll,
});

await session.SendAsync(new MessageOptions
Expand All @@ -66,7 +68,8 @@ public async Task Handles_Tool_Calling_Errors()

var session = await CreateSessionAsync(new SessionConfig
{
Tools = [getUserLocation]
Tools = [getUserLocation],
OnPermissionRequest = PermissionHandler.ApproveAll,
});

await session.SendAsync(new MessageOptions { Prompt = "What is my location? If you can't find out, just say 'unknown'." });
Expand Down Expand Up @@ -108,6 +111,7 @@ public async Task Can_Receive_And_Return_Complex_Types()
var session = await CreateSessionAsync(new SessionConfig
{
Tools = [AIFunctionFactory.Create(PerformDbQuery, "db_query", serializerOptions: ToolsTestsJsonContext.Default.Options)],
OnPermissionRequest = PermissionHandler.ApproveAll,
});

await session.SendAsync(new MessageOptions
Expand Down Expand Up @@ -154,6 +158,7 @@ public async Task Can_Return_Binary_Result()
var session = await CreateSessionAsync(new SessionConfig
{
Tools = [AIFunctionFactory.Create(GetImage, "get_image")],
OnPermissionRequest = PermissionHandler.ApproveAll,
});

await session.SendAsync(new MessageOptions
Expand All @@ -177,4 +182,72 @@ await session.SendAsync(new MessageOptions
SessionLog = "Returned an image",
});
}

[Fact]
public async Task Invokes_Custom_Tool_With_Permission_Handler()
{
var permissionRequests = new List<PermissionRequest>();

var session = await Client.CreateSessionAsync(new SessionConfig
{
Tools = [AIFunctionFactory.Create(EncryptStringForPermission, "encrypt_string")],
OnPermissionRequest = (request, invocation) =>
{
permissionRequests.Add(request);
return Task.FromResult(new PermissionRequestResult { Kind = "approved" });
},
});

await session.SendAsync(new MessageOptions
{
Prompt = "Use encrypt_string to encrypt this string: Hello"
});

var assistantMessage = await TestHelper.GetFinalAssistantMessageAsync(session);
Assert.NotNull(assistantMessage);
Assert.Contains("HELLO", assistantMessage!.Data.Content ?? string.Empty);

// Should have received a custom-tool permission request with the correct tool name
var customToolRequest = permissionRequests.FirstOrDefault(r => r.Kind == "custom-tool");
Assert.NotNull(customToolRequest);
Assert.True(customToolRequest!.ExtensionData?.ContainsKey("toolName") ?? false);
var toolName = ((JsonElement)customToolRequest.ExtensionData!["toolName"]).GetString();
Assert.Equal("encrypt_string", toolName);

[Description("Encrypts a string")]
static string EncryptStringForPermission([Description("String to encrypt")] string input)
=> input.ToUpperInvariant();
}

[Fact]
public async Task Denies_Custom_Tool_When_Permission_Denied()
{
var toolHandlerCalled = false;

var session = await Client.CreateSessionAsync(new SessionConfig
{
Tools = [AIFunctionFactory.Create(EncryptStringDenied, "encrypt_string")],
OnPermissionRequest = (request, invocation) =>
{
return Task.FromResult(new PermissionRequestResult { Kind = "denied-interactively-by-user" });
},
});

await session.SendAsync(new MessageOptions
{
Prompt = "Use encrypt_string to encrypt this string: Hello"
});

await TestHelper.GetFinalAssistantMessageAsync(session);

// The tool handler should NOT have been called since permission was denied
Assert.False(toolHandlerCalled);

[Description("Encrypts a string")]
string EncryptStringDenied([Description("String to encrypt")] string input)
{
toolHandlerCalled = true;
return input.ToUpperInvariant();
}
}
}
100 changes: 100 additions & 0 deletions go/internal/e2e/tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"testing"

copilot "github.com/github/copilot-sdk/go"
Expand Down Expand Up @@ -262,4 +263,103 @@ func TestTools(t *testing.T) {
t.Errorf("Expected session ID '%s', got '%s'", session.SessionID, receivedInvocation.SessionID)
}
})

t.Run("invokes custom tool with permission handler", func(t *testing.T) {
ctx.ConfigureForTest(t)

type EncryptParams struct {
Input string `json:"input" jsonschema:"String to encrypt"`
}

var permissionRequests []copilot.PermissionRequest
var mu sync.Mutex

session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{
Tools: []copilot.Tool{
copilot.DefineTool("encrypt_string", "Encrypts a string",
func(params EncryptParams, inv copilot.ToolInvocation) (string, error) {
return strings.ToUpper(params.Input), nil
}),
},
OnPermissionRequest: func(request copilot.PermissionRequest, invocation copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) {
mu.Lock()
permissionRequests = append(permissionRequests, request)
mu.Unlock()
return copilot.PermissionRequestResult{Kind: "approved"}, nil
},
})
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}

_, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

answer, err := testharness.GetFinalAssistantMessage(t.Context(), session)
if err != nil {
t.Fatalf("Failed to get assistant message: %v", err)
}

if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "HELLO") {
t.Errorf("Expected answer to contain 'HELLO', got %v", answer.Data.Content)
}

// Should have received a custom-tool permission request
mu.Lock()
customToolReqs := 0
for _, req := range permissionRequests {
if req.Kind == "custom-tool" {
customToolReqs++
if toolName, ok := req.Extra["toolName"].(string); !ok || toolName != "encrypt_string" {
t.Errorf("Expected toolName 'encrypt_string', got '%v'", req.Extra["toolName"])
}
}
}
mu.Unlock()
if customToolReqs == 0 {
t.Errorf("Expected at least one custom-tool permission request, got none")
}
})

t.Run("denies custom tool when permission denied", func(t *testing.T) {
ctx.ConfigureForTest(t)

type EncryptParams struct {
Input string `json:"input" jsonschema:"String to encrypt"`
}

toolHandlerCalled := false

session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{
Tools: []copilot.Tool{
copilot.DefineTool("encrypt_string", "Encrypts a string",
func(params EncryptParams, inv copilot.ToolInvocation) (string, error) {
toolHandlerCalled = true
return strings.ToUpper(params.Input), nil
}),
},
OnPermissionRequest: func(request copilot.PermissionRequest, invocation copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) {
return copilot.PermissionRequestResult{Kind: "denied-interactively-by-user"}, nil
},
})
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}

_, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

_, err = testharness.GetFinalAssistantMessage(t.Context(), session)
if err != nil {
t.Fatalf("Failed to get assistant message: %v", err)
}

if toolHandlerCalled {
t.Errorf("Tool handler should NOT have been called since permission was denied")
}
})
}
26 changes: 26 additions & 0 deletions go/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,32 @@ type PermissionRequest struct {
Extra map[string]any `json:"-"` // Additional fields vary by kind
}

// UnmarshalJSON implements custom JSON unmarshaling for PermissionRequest
// to capture additional fields (varying by kind) into the Extra map.
func (p *PermissionRequest) UnmarshalJSON(data []byte) error {
// Unmarshal known fields via an alias to avoid infinite recursion
type Alias PermissionRequest
var alias Alias
if err := json.Unmarshal(data, &alias); err != nil {
return err
}
*p = PermissionRequest(alias)

// Unmarshal all fields into a generic map
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}

// Remove known fields, keep the rest as Extra
delete(raw, "kind")
delete(raw, "toolCallId")
if len(raw) > 0 {
p.Extra = raw
}
return nil
}

// PermissionRequestResult represents the result of a permission request
type PermissionRequestResult struct {
Kind string `json:"kind"`
Expand Down
2 changes: 1 addition & 1 deletion nodejs/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ export type SystemMessageConfig = SystemMessageAppendConfig | SystemMessageRepla
* Permission request types from the server
*/
export interface PermissionRequest {
kind: "shell" | "write" | "mcp" | "read" | "url";
kind: "shell" | "write" | "mcp" | "read" | "url" | "custom-tool";
toolCallId?: string;
[key: string]: unknown;
}
Expand Down
63 changes: 63 additions & 0 deletions nodejs/test/e2e/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { join } from "path";
import { assert, describe, expect, it } from "vitest";
import { z } from "zod";
import { defineTool, approveAll } from "../../src/index.js";
import type { PermissionRequest } from "../../src/index.js";
import { createSdkTestContext } from "./harness/sdkTestContext";

describe("Custom tools", async () => {
Expand Down Expand Up @@ -36,6 +37,7 @@ describe("Custom tools", async () => {
handler: ({ input }) => input.toUpperCase(),
}),
],
onPermissionRequest: approveAll,
});

const assistantMessage = await session.sendAndWait({
Expand All @@ -55,6 +57,7 @@ describe("Custom tools", async () => {
},
}),
],
onPermissionRequest: approveAll,
});

const answer = await session.sendAndWait({
Expand Down Expand Up @@ -111,6 +114,7 @@ describe("Custom tools", async () => {
},
}),
],
onPermissionRequest: approveAll,
});

const assistantMessage = await session.sendAndWait({
Expand All @@ -127,4 +131,63 @@ describe("Custom tools", async () => {
expect(responseContent.replace(/,/g, "")).toContain("135460");
expect(responseContent.replace(/,/g, "")).toContain("204356");
});

it("invokes custom tool with permission handler", async () => {
const permissionRequests: PermissionRequest[] = [];

const session = await client.createSession({
tools: [
defineTool("encrypt_string", {
description: "Encrypts a string",
parameters: z.object({
input: z.string().describe("String to encrypt"),
}),
handler: ({ input }) => input.toUpperCase(),
}),
],
onPermissionRequest: (request) => {
permissionRequests.push(request);
return { kind: "approved" };
},
});

const assistantMessage = await session.sendAndWait({
prompt: "Use encrypt_string to encrypt this string: Hello",
});
expect(assistantMessage?.data.content).toContain("HELLO");

// Should have received a custom-tool permission request
const customToolRequests = permissionRequests.filter((req) => req.kind === "custom-tool");
expect(customToolRequests.length).toBeGreaterThan(0);
expect(customToolRequests[0].toolName).toBe("encrypt_string");
});

it("denies custom tool when permission denied", async () => {
let toolHandlerCalled = false;

const session = await client.createSession({
tools: [
defineTool("encrypt_string", {
description: "Encrypts a string",
parameters: z.object({
input: z.string().describe("String to encrypt"),
}),
handler: ({ input }) => {
toolHandlerCalled = true;
return input.toUpperCase();
},
}),
],
onPermissionRequest: () => {
return { kind: "denied-interactively-by-user" };
},
});

await session.sendAndWait({
prompt: "Use encrypt_string to encrypt this string: Hello",
});

// The tool handler should NOT have been called since permission was denied
expect(toolHandlerCalled).toBe(false);
});
});
2 changes: 1 addition & 1 deletion python/copilot/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class SystemMessageReplaceConfig(TypedDict):
class PermissionRequest(TypedDict, total=False):
"""Permission request from the server"""

kind: Literal["shell", "write", "mcp", "read", "url"]
kind: Literal["shell", "write", "mcp", "read", "url", "custom-tool"]
toolCallId: str
# Additional fields vary by kind

Expand Down
Loading
Loading