diff --git a/README.md b/README.md index b5b7154..318a80e 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,24 @@ turn, err := thread.Run(ctx, "Write a one sentence update.", &godex.TurnOptions{ }) ``` +### Typed helpers + +Generate and decode structured JSON into Go types with `RunJSON` / `RunStreamedJSON`. Provide +your own schema or allow the helpers to infer one from `T`: + +```go +type Update struct { + Headline string `json:"headline"` + NextStep string `json:"next_step"` +} + +result, err := godex.RunJSON[Update](ctx, thread, "Provide a concise update.", nil) +if err != nil { + log.Fatal(err) +} +log.Printf("update: %+v", result) +``` + ## Multi-part input and local images Mix text segments and local image paths by using `RunInputs` / `RunStreamedInputs` with @@ -119,6 +137,7 @@ fmt.Println("Assistant:", turn.FinalResponse) - `examples/basic`: single-turn conversation (`go run ./examples/basic`) - `examples/streaming`: step-by-step event streaming demo (`go run ./examples/streaming`) - `examples/schema`: structured JSON output with schema validation (`go run ./examples/schema`) +- `examples/structured_output`: typed structured output helpers (`go run ./examples/structured_output`) - `examples/images`: multi-part prompt mixing text and a local image (`go run ./examples/images`) ## Thread persistence diff --git a/examples/structured_output/main.go b/examples/structured_output/main.go new file mode 100644 index 0000000..943e0e1 --- /dev/null +++ b/examples/structured_output/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/activadee/godex" +) + +type projectUpdate struct { + Headline string `json:"headline" jsonschema:"description=Short summary of the update"` + NextStep string `json:"next_step" jsonschema:"description=Concrete follow-up action"` +} + +func main() { + client, err := godex.New(godex.CodexOptions{}) + if err != nil { + log.Fatalf("create codex client: %v", err) + } + + thread := client.StartThread(godex.ThreadOptions{ + Model: "gpt-5", + }) + + update, err := godex.RunJSON[projectUpdate](context.Background(), thread, "Provide a concise project update and a suggested next step.", nil) + if err != nil { + log.Fatalf("run structured turn: %v", err) + } + + fmt.Printf("Headline: %s\nNext step: %s\n", update.Headline, update.NextStep) + + streamed, err := godex.RunStreamedJSON[projectUpdate](context.Background(), thread, "Give another update and next step, streaming partial results.", nil) + if err != nil { + log.Fatalf("start streamed structured turn: %v", err) + } + defer streamed.Close() + + for update := range streamed.Updates() { + fmt.Printf("[structured update] final=%t headline=%q next_step=%q\n", update.Final, update.Value.Headline, update.Value.NextStep) + } + + if err := streamed.Wait(); err != nil { + log.Fatalf("streamed structured turn failed: %v", err) + } +} diff --git a/go.mod b/go.mod index 98e59f3..a512435 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,13 @@ module github.com/activadee/godex go 1.22 + +require github.com/invopop/jsonschema v0.13.0 + +require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..38ac45c --- /dev/null +++ b/go.sum @@ -0,0 +1,21 @@ +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/structured_output.go b/structured_output.go new file mode 100644 index 0000000..9489fd6 --- /dev/null +++ b/structured_output.go @@ -0,0 +1,361 @@ +package godex + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/invopop/jsonschema" +) + +var ( + // ErrNoStructuredOutput indicates that the turn completed without returning a structured + // response that could be decoded into the requested type. + ErrNoStructuredOutput = errors.New("structured output not returned") +) + +const runStreamedJSONEventBuffer = 16 + +// RunJSONOptions configure a typed JSON turn. +type RunJSONOptions[T any] struct { + // TurnOptions forwards additional options for the turn. When nil a zero TurnOptions + // value is used. + TurnOptions *TurnOptions + // Schema provides an explicit JSON schema for the structured output. When nil the + // helper attempts schema inference unless DisableSchemaInference is true. + Schema any + // DisableSchemaInference prevents automatic schema inference from T when Schema is nil. + DisableSchemaInference bool +} + +// SchemaViolationError indicates that the structured output failed schema validation. +type SchemaViolationError struct { + Message string +} + +// Error implements the error interface. +func (e *SchemaViolationError) Error() string { + if e == nil || e.Message == "" { + return "structured output schema violation" + } + return e.Message +} + +// RunJSON executes a turn expecting a structured JSON response that can be decoded into T. +func RunJSON[T any](ctx context.Context, thread *Thread, input string, options *RunJSONOptions[T]) (T, error) { + var zero T + + if thread == nil { + return zero, errors.New("RunJSON requires a non-nil thread") + } + + config, err := prepareRunJSONOptions[T](options) + if err != nil { + return zero, err + } + + result, err := thread.run(ctx, input, nil, &config.turnOptions) + if err != nil { + if schemaErr, ok := classifyStructuredOutputError(err, config.expectSchemaError); ok { + return zero, schemaErr + } + return zero, err + } + + var value T + if err := json.Unmarshal([]byte(result.FinalResponse), &value); err != nil { + return zero, fmt.Errorf("decode structured output: %w", err) + } + return value, nil +} + +// RunStreamedJSONUpdate captures a typed snapshot of the structured output as the turn progresses. +type RunStreamedJSONUpdate[T any] struct { + Value T + Raw string + Final bool +} + +// RunStreamedJSONResult exposes the streaming lifecycle for a typed structured output turn. +type RunStreamedJSONResult[T any] struct { + stream *Stream + events <-chan ThreadEvent + updates <-chan RunStreamedJSONUpdate[T] + err *sharedError + done <-chan struct{} +} + +// Events returns the stream of raw thread events produced by the turn. +func (r RunStreamedJSONResult[T]) Events() <-chan ThreadEvent { + return r.events +} + +// Updates yields typed structured output snapshots. The channel closes once the turn finishes. +func (r RunStreamedJSONResult[T]) Updates() <-chan RunStreamedJSONUpdate[T] { + return r.updates +} + +// Wait blocks until the turn finishes and returns the terminal error, if any. +func (r RunStreamedJSONResult[T]) Wait() error { + var done <-chan struct{} + if r.done != nil { + done = r.done + } + if r.stream == nil { + if done != nil { + <-done + } + if r.err != nil { + return r.err.get() + } + return nil + } + if err := r.stream.Wait(); err != nil { + if done != nil { + <-done + } + return err + } + if done != nil { + <-done + } + if r.err != nil { + return r.err.get() + } + return nil +} + +// Close cancels the turn and waits for shutdown. +func (r RunStreamedJSONResult[T]) Close() error { + var done <-chan struct{} + if r.done != nil { + done = r.done + } + if r.stream == nil { + if done != nil { + <-done + } + if r.err != nil { + return r.err.get() + } + return nil + } + if err := r.stream.Close(); err != nil { + if done != nil { + <-done + } + return err + } + if done != nil { + <-done + } + if r.err != nil { + return r.err.get() + } + return nil +} + +// RunStreamedJSON executes a turn expecting structured JSON output and streams raw events +// alongside typed snapshots decoded into T. +func RunStreamedJSON[T any](ctx context.Context, thread *Thread, input string, options *RunJSONOptions[T]) (RunStreamedJSONResult[T], error) { + config, err := prepareRunJSONOptions[T](options) + if err != nil { + return RunStreamedJSONResult[T]{}, err + } + + if thread == nil { + return RunStreamedJSONResult[T]{}, errors.New("RunStreamedJSON requires a non-nil thread") + } + + raw, err := thread.runStreamed(ctx, input, nil, &config.turnOptions) + if err != nil { + return RunStreamedJSONResult[T]{}, err + } + + events := make(chan ThreadEvent, runStreamedJSONEventBuffer) + updates := make(chan RunStreamedJSONUpdate[T], runStreamedJSONEventBuffer) + shErr := &sharedError{} + fanoutDone := make(chan struct{}) + + result := RunStreamedJSONResult[T]{ + stream: raw.stream, + events: events, + updates: updates, + err: shErr, + done: fanoutDone, + } + + go func() { + defer close(fanoutDone) + defer close(events) + defer close(updates) + + var deliveredFinal bool + var turnCompleted bool + + for event := range raw.Events() { + switch e := event.(type) { + case ItemUpdatedEvent: + if msg, ok := e.Item.(AgentMessageItem); ok { + if update, decodeErr := decodeStructuredMessage[T](msg, false); decodeErr == nil { + select { + case updates <- update: + case <-raw.stream.done: + return + default: + // Drop intermediate snapshot when the consumer ignores updates. + } + } + } + case ItemCompletedEvent: + if msg, ok := e.Item.(AgentMessageItem); ok { + update, decodeErr := decodeStructuredMessage[T](msg, true) + if decodeErr != nil { + shErr.set(decodeErr) + } else { + deliveredFinal = true + select { + case updates <- update: + case <-raw.stream.done: + return + default: + // Drop final snapshot when the consumer ignores updates. + } + } + } + case TurnCompletedEvent: + turnCompleted = true + case TurnFailedEvent: + rawErr := errors.New(e.Error.Message) + if schemaErr, ok := classifyStructuredOutputError(rawErr, config.expectSchemaError); ok { + shErr.set(schemaErr) + } else { + shErr.set(rawErr) + } + } + + select { + case events <- event: + default: + // Drop events when no consumer is attached to avoid blocking snapshot updates. + } + } + + if turnCompleted && !deliveredFinal { + shErr.set(ErrNoStructuredOutput) + } + }() + + return result, nil +} + +type runJSONConfig struct { + turnOptions TurnOptions + expectSchemaError bool +} + +func prepareRunJSONOptions[T any](options *RunJSONOptions[T]) (runJSONConfig, error) { + var config runJSONConfig + + if options != nil && options.TurnOptions != nil { + config.turnOptions = *options.TurnOptions + } + + var schema any + if options != nil && options.Schema != nil { + schema = options.Schema + } else if config.turnOptions.OutputSchema != nil { + schema = config.turnOptions.OutputSchema + } else if options == nil || !options.DisableSchemaInference { + inferred, err := inferSchemaForType[T]() + if err != nil { + return config, err + } + schema = inferred + config.expectSchemaError = true + } else { + return config, errors.New("RunJSON requires a schema; provide RunJSONOptions.Schema or TurnOptions.OutputSchema") + } + + if schema == nil { + return config, errors.New("RunJSON resolved nil schema") + } + + config.turnOptions.OutputSchema = schema + if !config.expectSchemaError && schema != nil { + config.expectSchemaError = true + } + + return config, nil +} + +func classifyStructuredOutputError(err error, expectSchema bool) (error, bool) { + if err == nil || !expectSchema { + return nil, false + } + var streamErr *ThreadStreamError + if errors.As(err, &streamErr) { + return nil, false + } + + message := err.Error() + if message == "" { + return &SchemaViolationError{}, true + } + + lower := strings.ToLower(message) + if strings.Contains(lower, "schema") || strings.Contains(lower, "structured output") || strings.Contains(lower, "validation") { + return &SchemaViolationError{Message: message}, true + } + return nil, false +} + +func decodeStructuredMessage[T any](msg AgentMessageItem, final bool) (RunStreamedJSONUpdate[T], error) { + var value T + if err := json.Unmarshal([]byte(msg.Text), &value); err != nil { + if final { + return RunStreamedJSONUpdate[T]{}, fmt.Errorf("decode structured output: %w", err) + } + return RunStreamedJSONUpdate[T]{}, err + } + return RunStreamedJSONUpdate[T]{ + Value: value, + Raw: msg.Text, + Final: final, + }, nil +} + +func inferSchemaForType[T any]() (*jsonschema.Schema, error) { + t := reflect.TypeOf((*T)(nil)).Elem() + if t == nil { + return nil, errors.New("cannot infer schema for nil type") + } + ref := &jsonschema.Reflector{} + return ref.ReflectFromType(t), nil +} + +type sharedError struct { + mu sync.Mutex + err error +} + +func (s *sharedError) set(err error) { + if err == nil { + return + } + s.mu.Lock() + if s.err == nil { + s.err = err + } + s.mu.Unlock() +} + +func (s *sharedError) get() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.err +} diff --git a/structured_output_test.go b/structured_output_test.go new file mode 100644 index 0000000..99de60f --- /dev/null +++ b/structured_output_test.go @@ -0,0 +1,176 @@ +package godex + +import ( + "context" + "errors" + "testing" +) + +type structuredUpdate struct { + Headline string `json:"headline"` + NextStep string `json:"next_step"` +} + +func TestRunJSONReturnsTypedValue(t *testing.T) { + events := marshalEvents(t, []map[string]any{ + {"type": "thread.started", "thread_id": "thread_1"}, + {"type": "item.completed", "item": map[string]any{ + "id": "msg_1", + "type": "agent_message", + "text": `{"headline":"Release ready","next_step":"Ship it"}`, + }}, + {"type": "turn.completed", "usage": map[string]any{"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}}, + }) + + runner := &fakeRunner{t: t, batches: []fakeRun{{events: events}}} + thread := newThread(runner, CodexOptions{}, ThreadOptions{}, "") + + update, err := RunJSON[structuredUpdate](context.Background(), thread, "structured", nil) + if err != nil { + t.Fatalf("RunJSON returned error: %v", err) + } + + if update.Headline != "Release ready" || update.NextStep != "Ship it" { + t.Fatalf("unexpected update: %+v", update) + } + + if call := runner.lastCall(); call.OutputSchemaPath == "" { + t.Fatal("expected OutputSchemaPath to be set") + } +} + +func TestRunJSONSchemaViolation(t *testing.T) { + events := marshalEvents(t, []map[string]any{ + {"type": "thread.started", "thread_id": "thread_1"}, + {"type": "turn.failed", "error": map[string]any{"message": "Structured output schema violation: missing property 'headline'"}}, + }) + + runner := &fakeRunner{t: t, batches: []fakeRun{{events: events}}} + thread := newThread(runner, CodexOptions{}, ThreadOptions{}, "") + + _, err := RunJSON[structuredUpdate](context.Background(), thread, "structured", nil) + if err == nil { + t.Fatal("expected RunJSON to return error") + } + + var schemaErr *SchemaViolationError + if !errors.As(err, &schemaErr) { + t.Fatalf("expected SchemaViolationError, got %T", err) + } + if schemaErr.Message == "" { + t.Fatal("expected schema error message to be populated") + } +} + +func TestRunJSONRequiresSchemaWhenInferenceDisabled(t *testing.T) { + thread := newThread(&fakeRunner{t: t}, CodexOptions{}, ThreadOptions{}, "") + + _, err := RunJSON[structuredUpdate](context.Background(), thread, "structured", &RunJSONOptions[structuredUpdate]{ + DisableSchemaInference: true, + }) + if err == nil { + t.Fatal("expected RunJSON to fail without schema when inference disabled") + } +} + +func TestRunStreamedJSONEmitsUpdates(t *testing.T) { + events := marshalEvents(t, []map[string]any{ + {"type": "thread.started", "thread_id": "thread_1"}, + {"type": "item.updated", "item": map[string]any{ + "id": "msg_1", + "type": "agent_message", + "text": `{"headline":"Draft message","next_step":"Review"}`, + }}, + {"type": "item.completed", "item": map[string]any{ + "id": "msg_1", + "type": "agent_message", + "text": `{"headline":"Final headline","next_step":"Publish"}`, + }}, + {"type": "turn.completed", "usage": map[string]any{"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}}, + }) + + runner := &fakeRunner{t: t, batches: []fakeRun{{events: events}}} + thread := newThread(runner, CodexOptions{}, ThreadOptions{}, "") + + result, err := RunStreamedJSON[structuredUpdate](context.Background(), thread, "structured", nil) + if err != nil { + t.Fatalf("RunStreamedJSON returned error: %v", err) + } + defer result.Close() + + var updates []RunStreamedJSONUpdate[structuredUpdate] + for update := range result.Updates() { + updates = append(updates, update) + } + + if err := result.Wait(); err != nil { + t.Fatalf("result.Wait returned error: %v", err) + } + + if len(updates) != 2 { + t.Fatalf("expected 2 updates, got %d", len(updates)) + } + if updates[0].Final { + t.Fatal("expected first update to be non-final") + } + if !updates[1].Final { + t.Fatal("expected second update to be final") + } + if updates[1].Value.Headline != "Final headline" || updates[1].Value.NextStep != "Publish" { + t.Fatalf("unexpected final update: %+v", updates[1].Value) + } +} + +func TestRunStreamedJSONSchemaViolation(t *testing.T) { + events := marshalEvents(t, []map[string]any{ + {"type": "thread.started", "thread_id": "thread_1"}, + {"type": "turn.failed", "error": map[string]any{"message": "structured output schema violation: headline missing"}}, + }) + + runner := &fakeRunner{t: t, batches: []fakeRun{{events: events}}} + thread := newThread(runner, CodexOptions{}, ThreadOptions{}, "") + + result, err := RunStreamedJSON[structuredUpdate](context.Background(), thread, "structured", nil) + if err != nil { + t.Fatalf("RunStreamedJSON returned error: %v", err) + } + defer result.Close() + + for range result.Updates() { + // drain updates + } + + waitErr := result.Wait() + if waitErr == nil { + t.Fatal("expected Wait to return error") + } + + var schemaErr *SchemaViolationError + if !errors.As(waitErr, &schemaErr) { + t.Fatalf("expected SchemaViolationError, got %T", waitErr) + } + if schemaErr.Message == "" { + t.Fatal("expected schema error message to be populated") + } +} + +func TestRunStreamedJSONWaitWithoutUpdatesConsumer(t *testing.T) { + events := marshalEvents(t, []map[string]any{ + {"type": "thread.started", "thread_id": "thread_1"}, + {"type": "turn.completed", "usage": map[string]any{"input_tokens": 1, "cached_input_tokens": 0, "output_tokens": 1}}, + }) + + runner := &fakeRunner{t: t, batches: []fakeRun{{events: events}}} + thread := newThread(runner, CodexOptions{}, ThreadOptions{}, "") + + result, err := RunStreamedJSON[structuredUpdate](context.Background(), thread, "structured", nil) + if err != nil { + t.Fatalf("RunStreamedJSON returned error: %v", err) + } + defer result.Close() + + waitErr := result.Wait() + if !errors.Is(waitErr, ErrNoStructuredOutput) { + t.Fatalf("expected ErrNoStructuredOutput, got %v", waitErr) + } +}