diff --git a/README.md b/README.md index 278eadf..c247ea6 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,52 @@ if err := result.Wait(); err != nil { } ``` +### Streaming callbacks + +Set `TurnOptions.Callbacks` to receive typed updates without writing a `switch` over +`ThreadEvent`. The SDK invokes `OnEvent` first, followed by any matching typed callbacks, +and finally forwards the raw event through `Events()`. Callbacks run on the streaming goroutine, +so long-running work should be offloaded to avoid stalling the stream. You must continue +draining the `Events()` channel (an empty `for range` loop works) to honour the CLI's +backpressure expectations. + +```go +callbacks := &godex.StreamCallbacks{ + OnMessage: func(evt godex.StreamMessageEvent) { + if evt.Stage == godex.StreamItemStageCompleted { + log.Printf("assistant: %s", evt.Message.Text) + } + }, + OnCommand: func(evt godex.StreamCommandEvent) { + log.Printf("command %s: %s", evt.Command.Status, evt.Command.Command) + }, + OnPatch: func(evt godex.StreamPatchEvent) { + log.Printf("patch %s: %s", evt.Patch.ID, evt.Patch.Status) + }, + OnFileChange: func(evt godex.StreamFileChangeEvent) { + log.Printf(" file %s (%s)", evt.Change.Path, evt.Change.Kind) + }, +} + +result, err := thread.RunStreamed(ctx, "Summarize the latest changes.", &godex.TurnOptions{ + Callbacks: callbacks, +}) +if err != nil { + log.Fatal(err) +} +defer result.Close() + +for range result.Events() { + // Drain events; callbacks handled the typed work already. +} + +if err := result.Wait(); err != nil { + log.Fatal(err) +} +``` + +See `examples/streaming_callbacks` for a complete runnable sample. + ## Structured output Pass a JSON schema in `TurnOptions.OutputSchema` and the SDK writes a temporary file for the CLI: diff --git a/callbacks.go b/callbacks.go new file mode 100644 index 0000000..4771943 --- /dev/null +++ b/callbacks.go @@ -0,0 +1,176 @@ +package godex + +// StreamItemStage indicates which phase of the lifecycle produced a callback. +type StreamItemStage string + +const ( + StreamItemStageStarted StreamItemStage = "started" + StreamItemStageUpdated StreamItemStage = "updated" + StreamItemStageCompleted StreamItemStage = "completed" +) + +// StreamMessageEvent describes a callback payload for agent message items. +type StreamMessageEvent struct { + Stage StreamItemStage + Message AgentMessageItem +} + +// StreamReasoningEvent describes a callback payload for reasoning items. +type StreamReasoningEvent struct { + Stage StreamItemStage + Reasoning ReasoningItem +} + +// StreamCommandEvent describes a callback payload for command execution items. +type StreamCommandEvent struct { + Stage StreamItemStage + Command CommandExecutionItem +} + +// StreamPatchEvent describes a callback payload for patch/file change items. +type StreamPatchEvent struct { + Stage StreamItemStage + Patch FileChangeItem +} + +// StreamFileChangeEvent describes a callback payload for each file updated within a patch. +type StreamFileChangeEvent struct { + Stage StreamItemStage + Patch FileChangeItem + Change FileUpdateChange +} + +// StreamWebSearchEvent describes a callback payload for web search items. +type StreamWebSearchEvent struct { + Stage StreamItemStage + Search WebSearchItem +} + +// StreamToolCallEvent describes a callback payload for MCP tool call items. +type StreamToolCallEvent struct { + Stage StreamItemStage + ToolCall McpToolCallItem +} + +// StreamTodoListEvent describes a callback payload for todo list items. +type StreamTodoListEvent struct { + Stage StreamItemStage + List TodoListItem +} + +// StreamErrorItemEvent describes a callback payload for non-fatal error items. +type StreamErrorItemEvent struct { + Stage StreamItemStage + Error ErrorItem +} + +// StreamCallbacks enumerates optional hooks invoked when streaming events are delivered. +type StreamCallbacks struct { + // OnEvent fires for every event before any type-specific callback. + OnEvent func(ThreadEvent) + + OnThreadStarted func(ThreadStartedEvent) + OnTurnStarted func(TurnStartedEvent) + OnTurnCompleted func(TurnCompletedEvent) + OnTurnFailed func(TurnFailedEvent) + OnThreadError func(ThreadErrorEvent) + + OnMessage func(StreamMessageEvent) + OnReasoning func(StreamReasoningEvent) + OnCommand func(StreamCommandEvent) + OnPatch func(StreamPatchEvent) + OnFileChange func(StreamFileChangeEvent) + OnWebSearch func(StreamWebSearchEvent) + OnToolCall func(StreamToolCallEvent) + OnTodoList func(StreamTodoListEvent) + OnErrorItem func(StreamErrorItemEvent) +} + +func (c *StreamCallbacks) handle(event ThreadEvent) { + if c == nil { + return + } + + if c.OnEvent != nil { + c.OnEvent(event) + } + + switch e := event.(type) { + case ThreadStartedEvent: + if c.OnThreadStarted != nil { + c.OnThreadStarted(e) + } + case TurnStartedEvent: + if c.OnTurnStarted != nil { + c.OnTurnStarted(e) + } + case TurnCompletedEvent: + if c.OnTurnCompleted != nil { + c.OnTurnCompleted(e) + } + case TurnFailedEvent: + if c.OnTurnFailed != nil { + c.OnTurnFailed(e) + } + case ThreadErrorEvent: + if c.OnThreadError != nil { + c.OnThreadError(e) + } + case ItemStartedEvent: + c.handleItem(StreamItemStageStarted, e.Item) + case ItemUpdatedEvent: + c.handleItem(StreamItemStageUpdated, e.Item) + case ItemCompletedEvent: + c.handleItem(StreamItemStageCompleted, e.Item) + } +} + +func (c *StreamCallbacks) handleItem(stage StreamItemStage, item ThreadItem) { + if c == nil || item == nil { + return + } + + switch v := item.(type) { + case AgentMessageItem: + if c.OnMessage != nil { + c.OnMessage(StreamMessageEvent{Stage: stage, Message: v}) + } + case ReasoningItem: + if c.OnReasoning != nil { + c.OnReasoning(StreamReasoningEvent{Stage: stage, Reasoning: v}) + } + case CommandExecutionItem: + if c.OnCommand != nil { + c.OnCommand(StreamCommandEvent{Stage: stage, Command: v}) + } + case FileChangeItem: + if c.OnPatch != nil { + c.OnPatch(StreamPatchEvent{Stage: stage, Patch: v}) + } + if c.OnFileChange != nil { + for _, change := range v.Changes { + c.OnFileChange(StreamFileChangeEvent{ + Stage: stage, + Patch: v, + Change: change, + }) + } + } + case McpToolCallItem: + if c.OnToolCall != nil { + c.OnToolCall(StreamToolCallEvent{Stage: stage, ToolCall: v}) + } + case WebSearchItem: + if c.OnWebSearch != nil { + c.OnWebSearch(StreamWebSearchEvent{Stage: stage, Search: v}) + } + case TodoListItem: + if c.OnTodoList != nil { + c.OnTodoList(StreamTodoListEvent{Stage: stage, List: v}) + } + case ErrorItem: + if c.OnErrorItem != nil { + c.OnErrorItem(StreamErrorItemEvent{Stage: stage, Error: v}) + } + } +} diff --git a/examples/streaming_callbacks/main.go b/examples/streaming_callbacks/main.go new file mode 100644 index 0000000..72b2532 --- /dev/null +++ b/examples/streaming_callbacks/main.go @@ -0,0 +1,62 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/activadee/godex" +) + +func main() { + client, err := godex.New(godex.CodexOptions{}) + if err != nil { + log.Fatalf("create codex client: %v", err) + } + + thread := client.StartThread(godex.ThreadOptions{ + Model: "gpt-5", + }) + + callbacks := &godex.StreamCallbacks{ + OnMessage: func(evt godex.StreamMessageEvent) { + switch evt.Stage { + case godex.StreamItemStageUpdated: + fmt.Printf("[assistant partial] %s\n", evt.Message.Text) + case godex.StreamItemStageCompleted: + fmt.Printf("[assistant final] %s\n", evt.Message.Text) + } + }, + OnCommand: func(evt godex.StreamCommandEvent) { + fmt.Printf("[command %s] %s\n", evt.Command.Status, evt.Command.Command) + }, + OnPatch: func(evt godex.StreamPatchEvent) { + fmt.Printf("[patch %s] status=%s\n", evt.Patch.ID, evt.Patch.Status) + }, + OnFileChange: func(evt godex.StreamFileChangeEvent) { + fmt.Printf(" file %s (%s)\n", evt.Change.Path, evt.Change.Kind) + }, + OnWebSearch: func(evt godex.StreamWebSearchEvent) { + fmt.Printf("[web search] %s\n", evt.Search.Query) + }, + OnThreadError: func(evt godex.ThreadErrorEvent) { + log.Printf("[stream error] %s", evt.Message) + }, + } + + result, err := thread.RunStreamed(context.Background(), "Summarize the latest SDK changes and list next steps.", &godex.TurnOptions{ + Callbacks: callbacks, + }) + if err != nil { + log.Fatalf("start streamed run: %v", err) + } + defer result.Close() + + for range result.Events() { + // Drain events to honour backpressure; callbacks already handled rendering. + } + + if err := result.Wait(); err != nil { + log.Fatalf("stream failed: %v", err) + } +} diff --git a/options.go b/options.go index 5cd52a8..9c2a8a7 100644 --- a/options.go +++ b/options.go @@ -53,4 +53,6 @@ type TurnOptions struct { // OutputSchema is an optional JSON schema describing the structured response to // collect from the agent. Must serialize to a JSON object (not an array or primitive). OutputSchema any + // Callbacks attaches optional streaming callbacks invoked as events arrive. + Callbacks *StreamCallbacks } diff --git a/thread.go b/thread.go index 2d2ffda..4beabbe 100644 --- a/thread.go +++ b/thread.go @@ -102,6 +102,8 @@ func (t *Thread) runStreamed(ctx context.Context, baseInput string, segments []I turnOpts = *turnOptions } + callbacks := turnOpts.Callbacks + prepared, err := normalizeInput(baseInput, segments) if err != nil { return RunStreamedResult{}, err @@ -152,6 +154,10 @@ func (t *Thread) runStreamed(ctx context.Context, baseInput string, segments []I threadErr = &ThreadStreamError{ThreadError: ThreadError{Message: errEvent.Message}} } + if callbacks != nil { + callbacks.handle(event) + } + select { case events <- event: return nil diff --git a/thread_stream_test.go b/thread_stream_test.go index bef6e96..d45fdff 100644 --- a/thread_stream_test.go +++ b/thread_stream_test.go @@ -2,6 +2,7 @@ package godex import ( "context" + "sync" "testing" ) @@ -66,3 +67,136 @@ func TestThreadRunStreamedInputsForwardsImages(t *testing.T) { t.Fatalf("unexpected images slice: %v", call.Images) } } + +func TestStreamCallbacksDispatchTypedItems(t *testing.T) { + events := marshalEvents(t, []map[string]any{ + {"type": "thread.started", "thread_id": "thread_1"}, + {"type": "item.updated", "item": map[string]any{ + "id": "message_1", + "type": "agent_message", + "text": "partial: hello", + }}, + {"type": "item.updated", "item": map[string]any{ + "id": "command_1", + "type": "command_execution", + "command": "go test ./...", + "aggregated_output": "running tests", + "status": "in_progress", + }}, + {"type": "item.completed", "item": map[string]any{ + "id": "patch_1", + "type": "file_change", + "status": "completed", + "changes": []map[string]any{ + {"path": "main.go", "kind": "update"}, + {"path": "README.md", "kind": "update"}, + }, + }}, + {"type": "item.completed", "item": map[string]any{ + "id": "search_1", + "type": "web_search", + "query": "godex callbacks", + }}, + {"type": "turn.completed", "usage": map[string]any{"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}}, + }) + + runner := &fakeRunner{t: t, batches: []fakeRun{{events: events}}} + thread := newThread(runner, CodexOptions{}, ThreadOptions{}, "") + + var ( + mu sync.Mutex + messages []StreamMessageEvent + commands []StreamCommandEvent + patches []StreamPatchEvent + fileChanges []StreamFileChangeEvent + webSearches []StreamWebSearchEvent + ) + + callbacks := &StreamCallbacks{ + OnMessage: func(evt StreamMessageEvent) { + mu.Lock() + defer mu.Unlock() + messages = append(messages, evt) + }, + OnCommand: func(evt StreamCommandEvent) { + mu.Lock() + defer mu.Unlock() + commands = append(commands, evt) + }, + OnPatch: func(evt StreamPatchEvent) { + mu.Lock() + defer mu.Unlock() + patches = append(patches, evt) + }, + OnFileChange: func(evt StreamFileChangeEvent) { + mu.Lock() + defer mu.Unlock() + fileChanges = append(fileChanges, evt) + }, + OnWebSearch: func(evt StreamWebSearchEvent) { + mu.Lock() + defer mu.Unlock() + webSearches = append(webSearches, evt) + }, + } + + result, err := thread.RunStreamed(context.Background(), "callbacks please", &TurnOptions{Callbacks: callbacks}) + if err != nil { + t.Fatalf("RunStreamed returned error: %v", err) + } + defer result.Close() + + for range result.Events() { + // Drain events while callbacks handle type-specific logic. + } + + if err := result.Wait(); err != nil { + t.Fatalf("result.Wait returned error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if len(messages) != 1 { + t.Fatalf("expected 1 message callback, got %d", len(messages)) + } + if messages[0].Stage != StreamItemStageUpdated || messages[0].Message.Text != "partial: hello" { + t.Fatalf("unexpected message callback payload: %+v", messages[0]) + } + + if len(commands) != 1 { + t.Fatalf("expected 1 command callback, got %d", len(commands)) + } + if commands[0].Stage != StreamItemStageUpdated || commands[0].Command.Command != "go test ./..." { + t.Fatalf("unexpected command callback payload: %+v", commands[0]) + } + + if len(patches) != 1 { + t.Fatalf("expected 1 patch callback, got %d", len(patches)) + } + if patches[0].Stage != StreamItemStageCompleted || patches[0].Patch.ID != "patch_1" { + t.Fatalf("unexpected patch callback payload: %+v", patches[0]) + } + + if len(fileChanges) != 2 { + t.Fatalf("expected 2 file change callbacks, got %d", len(fileChanges)) + } + if fileChanges[0].Patch.ID != "patch_1" || fileChanges[0].Change.Path != "main.go" { + t.Fatalf("unexpected first file change payload: %+v", fileChanges[0]) + } + if fileChanges[1].Patch.ID != "patch_1" || fileChanges[1].Change.Path != "README.md" { + t.Fatalf("unexpected second file change payload: %+v", fileChanges[1]) + } + for _, change := range fileChanges { + if change.Stage != StreamItemStageCompleted { + t.Fatalf("expected completed stage for file change, got %+v", change.Stage) + } + } + + if len(webSearches) != 1 { + t.Fatalf("expected 1 web search callback, got %d", len(webSearches)) + } + if webSearches[0].Stage != StreamItemStageCompleted || webSearches[0].Search.Query != "godex callbacks" { + t.Fatalf("unexpected web search callback payload: %+v", webSearches[0]) + } +}