Skip to content

Commit a4ffbf0

Browse files
Improve abort tests across all SDKs; add Go unsubscribe tests (#48)
* Fix session event handler unsubscription and add tests The unsubscribe function was failing due to invalid function pointer comparisons. Refactored handler registration to use unique IDs for reliable cleanup. Tests verify: - Multiple handlers can be registered and all receive events - Unsubscribing one handler doesn't affect others - Calling unsubscribe multiple times is safe - Handlers are called in registration order - Concurrent subscribe/unsubscribe is safe Co-authored-by: nathfavour <116535483+nathfavour@users.noreply.github.com> * Fix abort test to use non-blocking send() * Clean up "should abort a session" * Add equivalent abort test improvements to Go, Python, and .NET * Formatting * More lint/format * Update test_session.py * Fix risk of flakiness --------- Co-authored-by: nathfavour <116535483+nathfavour@users.noreply.github.com>
1 parent d0b15ef commit a4ffbf0

File tree

16 files changed

+409
-63
lines changed

16 files changed

+409
-63
lines changed

dotnet/test/Harness/TestHelper.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,29 @@ async void CheckExistingMessages()
7373

7474
return null;
7575
}
76+
77+
public static async Task<T> GetNextEventOfTypeAsync<T>(
78+
CopilotSession session,
79+
TimeSpan? timeout = null) where T : SessionEvent
80+
{
81+
var tcs = new TaskCompletionSource<T>();
82+
using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(60));
83+
84+
using var subscription = session.On(evt =>
85+
{
86+
if (evt is T matched)
87+
{
88+
tcs.TrySetResult(matched);
89+
}
90+
else if (evt is SessionErrorEvent error)
91+
{
92+
tcs.TrySetException(new Exception(error.Data.Message ?? "session error"));
93+
}
94+
});
95+
96+
cts.Token.Register(() => tcs.TrySetException(
97+
new TimeoutException($"Timeout waiting for event of type '{typeof(T).Name}'")));
98+
99+
return await tcs.Task;
100+
}
76101
}

dotnet/test/SessionTests.cs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,23 +201,32 @@ public async Task Should_Abort_A_Session()
201201
{
202202
var session = await Client.CreateSessionAsync();
203203

204+
// Set up wait for tool execution to start BEFORE sending
205+
var toolStartTask = TestHelper.GetNextEventOfTypeAsync<ToolExecutionStartEvent>(session);
206+
var sessionIdleTask = TestHelper.GetNextEventOfTypeAsync<SessionIdleEvent>(session);
207+
204208
// Send a message that will take some time to process
205-
await session.SendAsync(new MessageOptions { Prompt = "What is 1+1?" });
209+
await session.SendAsync(new MessageOptions
210+
{
211+
Prompt = "run the shell command 'sleep 100' (note this works on both bash and PowerShell)"
212+
});
213+
214+
// Wait for tool execution to start
215+
await toolStartTask;
206216

207-
// Abort the session immediately
217+
// Abort the session
208218
await session.AbortAsync();
219+
await sessionIdleTask;
209220

210221
// The session should still be alive and usable after abort
211222
var messages = await session.GetMessagesAsync();
212223
Assert.NotEmpty(messages);
213224

214-
// TODO: We should do something to verify it really did abort (e.g., is there an abort event we can see,
215-
// or can we check that the session became idle without receiving an assistant message?). Right now
216-
// I'm not seeing any evidence that it actually does abort.
225+
// Verify an abort event exists in messages
226+
Assert.Contains(messages, m => m is AbortEvent);
217227

218228
// We should be able to send another message
219-
await session.SendAsync(new MessageOptions { Prompt = "What is 2+2?" });
220-
var answer = await TestHelper.GetFinalAssistantMessageAsync(session);
229+
var answer = await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 2+2?" });
221230
Assert.NotNull(answer);
222231
Assert.Contains("4", answer!.Data.Content ?? string.Empty);
223232
}

go/client.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ import (
3939
"strings"
4040
"sync"
4141
"time"
42-
43-
"github.com/github/copilot-sdk/go/generated"
4442
)
4543

4644
// Client manages the connection to the Copilot CLI server and provides session management.
@@ -923,7 +921,7 @@ func (c *Client) setupNotificationHandler() {
923921
return
924922
}
925923

926-
event, err := generated.UnmarshalSessionEvent(eventJSON)
924+
event, err := UnmarshalSessionEvent(eventJSON)
927925
if err != nil {
928926
return
929927
}

go/e2e/session_test.go

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -472,18 +472,57 @@ func TestSession(t *testing.T) {
472472
t.Fatalf("Failed to create session: %v", err)
473473
}
474474

475-
// Send a message that will take some time to process
476-
_, err = session.Send(copilot.MessageOptions{Prompt: "What is 1+1?"})
475+
// Set up event listeners BEFORE sending to avoid race conditions
476+
toolStartCh := make(chan *copilot.SessionEvent, 1)
477+
toolStartErrCh := make(chan error, 1)
478+
go func() {
479+
evt, err := testharness.GetNextEventOfType(session, copilot.ToolExecutionStart, 60*time.Second)
480+
if err != nil {
481+
toolStartErrCh <- err
482+
} else {
483+
toolStartCh <- evt
484+
}
485+
}()
486+
487+
sessionIdleCh := make(chan *copilot.SessionEvent, 1)
488+
sessionIdleErrCh := make(chan error, 1)
489+
go func() {
490+
evt, err := testharness.GetNextEventOfType(session, copilot.SessionIdle, 60*time.Second)
491+
if err != nil {
492+
sessionIdleErrCh <- err
493+
} else {
494+
sessionIdleCh <- evt
495+
}
496+
}()
497+
498+
// Send a message that triggers a long-running shell command
499+
_, err = session.Send(copilot.MessageOptions{Prompt: "run the shell command 'sleep 100' (note this works on both bash and PowerShell)"})
477500
if err != nil {
478501
t.Fatalf("Failed to send message: %v", err)
479502
}
480503

481-
// Abort the session immediately
504+
// Wait for tool.execution_start
505+
select {
506+
case <-toolStartCh:
507+
// Tool execution has started
508+
case err := <-toolStartErrCh:
509+
t.Fatalf("Failed waiting for tool.execution_start: %v", err)
510+
}
511+
512+
// Abort the session
482513
err = session.Abort()
483514
if err != nil {
484515
t.Fatalf("Failed to abort session: %v", err)
485516
}
486517

518+
// Wait for session.idle after abort
519+
select {
520+
case <-sessionIdleCh:
521+
// Session is idle
522+
case err := <-sessionIdleErrCh:
523+
t.Fatalf("Failed waiting for session.idle after abort: %v", err)
524+
}
525+
487526
// The session should still be alive and usable after abort
488527
messages, err := session.GetMessages()
489528
if err != nil {
@@ -493,15 +532,22 @@ func TestSession(t *testing.T) {
493532
t.Error("Expected messages to exist after abort")
494533
}
495534

496-
// We should be able to send another message
497-
_, err = session.Send(copilot.MessageOptions{Prompt: "What is 2+2?"})
498-
if err != nil {
499-
t.Fatalf("Failed to send message after abort: %v", err)
535+
// Verify messages contain an abort event
536+
hasAbortEvent := false
537+
for _, msg := range messages {
538+
if msg.Type == copilot.Abort {
539+
hasAbortEvent = true
540+
break
541+
}
542+
}
543+
if !hasAbortEvent {
544+
t.Error("Expected messages to contain an 'abort' event")
500545
}
501546

502-
answer, err := testharness.GetFinalAssistantMessage(session, 60*time.Second)
547+
// We should be able to send another message
548+
answer, err := session.SendAndWait(copilot.MessageOptions{Prompt: "What is 2+2?"}, 60*time.Second)
503549
if err != nil {
504-
t.Fatalf("Failed to get assistant message after abort: %v", err)
550+
t.Fatalf("Failed to send message after abort: %v", err)
505551
}
506552

507553
if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "4") {

go/e2e/testharness/helper.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,41 @@ func GetFinalAssistantMessage(session *copilot.Session, timeout time.Duration) (
5454
}
5555
}
5656

57+
// GetNextEventOfType waits for and returns the next event of the specified type from a session.
58+
func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEventType, timeout time.Duration) (*copilot.SessionEvent, error) {
59+
result := make(chan *copilot.SessionEvent, 1)
60+
errCh := make(chan error, 1)
61+
62+
unsubscribe := session.On(func(event copilot.SessionEvent) {
63+
switch event.Type {
64+
case eventType:
65+
select {
66+
case result <- &event:
67+
default:
68+
}
69+
case copilot.SessionError:
70+
msg := "session error"
71+
if event.Data.Message != nil {
72+
msg = *event.Data.Message
73+
}
74+
select {
75+
case errCh <- errors.New(msg):
76+
default:
77+
}
78+
}
79+
})
80+
defer unsubscribe()
81+
82+
select {
83+
case evt := <-result:
84+
return evt, nil
85+
case err := <-errCh:
86+
return nil, err
87+
case <-time.After(timeout):
88+
return nil, errors.New("timeout waiting for event: " + string(eventType))
89+
}
90+
}
91+
5792
func getExistingFinalResponse(session *copilot.Session) (*copilot.SessionEvent, error) {
5893
messages, err := session.GetMessages()
5994
if err != nil {
Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/session.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import (
66
"fmt"
77
"sync"
88
"time"
9-
10-
"github.com/github/copilot-sdk/go/generated"
119
)
1210

1311
type sessionHandler struct {
@@ -159,17 +157,17 @@ func (s *Session) SendAndWait(options MessageOptions, timeout time.Duration) (*S
159157

160158
unsubscribe := s.On(func(event SessionEvent) {
161159
switch event.Type {
162-
case generated.AssistantMessage:
160+
case AssistantMessage:
163161
mu.Lock()
164162
eventCopy := event
165163
lastAssistantMessage = &eventCopy
166164
mu.Unlock()
167-
case generated.SessionIdle:
165+
case SessionIdle:
168166
select {
169167
case idleCh <- struct{}{}:
170168
default:
171169
}
172-
case generated.SessionError:
170+
case SessionError:
173171
errMsg := "session error"
174172
if event.Data.Message != nil {
175173
errMsg = *event.Data.Message
@@ -387,7 +385,7 @@ func (s *Session) GetMessages() ([]SessionEvent, error) {
387385
continue
388386
}
389387

390-
event, err := generated.UnmarshalSessionEvent(eventJSON)
388+
event, err := UnmarshalSessionEvent(eventJSON)
391389
if err != nil {
392390
continue
393391
}

go/session_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package copilot
2+
3+
import (
4+
"sync"
5+
"testing"
6+
)
7+
8+
func TestSession_On(t *testing.T) {
9+
t.Run("multiple handlers all receive events", func(t *testing.T) {
10+
session := &Session{
11+
handlers: make([]sessionHandler, 0),
12+
}
13+
14+
var received1, received2, received3 bool
15+
session.On(func(event SessionEvent) { received1 = true })
16+
session.On(func(event SessionEvent) { received2 = true })
17+
session.On(func(event SessionEvent) { received3 = true })
18+
19+
session.dispatchEvent(SessionEvent{Type: "test"})
20+
21+
if !received1 || !received2 || !received3 {
22+
t.Errorf("Expected all handlers to receive event, got received1=%v, received2=%v, received3=%v",
23+
received1, received2, received3)
24+
}
25+
})
26+
27+
t.Run("unsubscribing one handler does not affect others", func(t *testing.T) {
28+
session := &Session{
29+
handlers: make([]sessionHandler, 0),
30+
}
31+
32+
var count1, count2, count3 int
33+
session.On(func(event SessionEvent) { count1++ })
34+
unsub2 := session.On(func(event SessionEvent) { count2++ })
35+
session.On(func(event SessionEvent) { count3++ })
36+
37+
// First event - all handlers receive it
38+
session.dispatchEvent(SessionEvent{Type: "test"})
39+
40+
// Unsubscribe handler 2
41+
unsub2()
42+
43+
// Second event - only handlers 1 and 3 should receive it
44+
session.dispatchEvent(SessionEvent{Type: "test"})
45+
46+
if count1 != 2 {
47+
t.Errorf("Expected handler 1 to receive 2 events, got %d", count1)
48+
}
49+
if count2 != 1 {
50+
t.Errorf("Expected handler 2 to receive 1 event (before unsubscribe), got %d", count2)
51+
}
52+
if count3 != 2 {
53+
t.Errorf("Expected handler 3 to receive 2 events, got %d", count3)
54+
}
55+
})
56+
57+
t.Run("calling unsubscribe multiple times is safe", func(t *testing.T) {
58+
session := &Session{
59+
handlers: make([]sessionHandler, 0),
60+
}
61+
62+
var count int
63+
unsub := session.On(func(event SessionEvent) { count++ })
64+
65+
session.dispatchEvent(SessionEvent{Type: "test"})
66+
67+
// Call unsubscribe multiple times - should not panic
68+
unsub()
69+
unsub()
70+
unsub()
71+
72+
session.dispatchEvent(SessionEvent{Type: "test"})
73+
74+
if count != 1 {
75+
t.Errorf("Expected handler to receive 1 event, got %d", count)
76+
}
77+
})
78+
79+
t.Run("handlers are called in registration order", func(t *testing.T) {
80+
session := &Session{
81+
handlers: make([]sessionHandler, 0),
82+
}
83+
84+
var order []int
85+
session.On(func(event SessionEvent) { order = append(order, 1) })
86+
session.On(func(event SessionEvent) { order = append(order, 2) })
87+
session.On(func(event SessionEvent) { order = append(order, 3) })
88+
89+
session.dispatchEvent(SessionEvent{Type: "test"})
90+
91+
if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 {
92+
t.Errorf("Expected handlers to be called in order [1,2,3], got %v", order)
93+
}
94+
})
95+
96+
t.Run("concurrent subscribe and unsubscribe is safe", func(t *testing.T) {
97+
session := &Session{
98+
handlers: make([]sessionHandler, 0),
99+
}
100+
101+
var wg sync.WaitGroup
102+
for i := 0; i < 100; i++ {
103+
wg.Add(1)
104+
go func() {
105+
defer wg.Done()
106+
unsub := session.On(func(event SessionEvent) {})
107+
unsub()
108+
}()
109+
}
110+
wg.Wait()
111+
112+
// Should not panic and handlers should be empty
113+
session.handlerMutex.RLock()
114+
count := len(session.handlers)
115+
session.handlerMutex.RUnlock()
116+
117+
if count != 0 {
118+
t.Errorf("Expected 0 handlers after all unsubscribes, got %d", count)
119+
}
120+
})
121+
}

0 commit comments

Comments
 (0)