From daa299a664d64b870f1ebcc4ae41497b29ccd3c1 Mon Sep 17 00:00:00 2001 From: Daniel Morris Date: Sat, 4 May 2024 10:41:30 +0100 Subject: [PATCH] feat: Add support for streaming server-sent events --- .github/workflows/ci.yaml | 6 +- README.md | 54 +++++++++++++---- messages.go | 38 ++++++------ messages_test.go | 11 ++++ streaming.go | 123 ++++++++++++++++++++++++++++++++++++++ streaming_test.go | 48 +++++++++++++++ 6 files changed, 246 insertions(+), 34 deletions(-) create mode 100644 streaming.go create mode 100644 streaming_test.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fb9c6d2..362a65e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -53,9 +53,9 @@ jobs: manifest-file: .github/release-manifest.json token: ${{ secrets.GITHUB_TOKEN }} - name: Update the version number - if: ${{ steps.release.outputs.prs_created == 'true' }} + if: steps.release.outputs.prs_created && steps.release.outputs.pr != null run: | - git config pull.ff only + git config pull.rebase true git checkout ${{ fromJSON(steps.release.outputs.pr).headBranchName }} git pull origin ${{ fromJSON(steps.release.outputs.pr).headBranchName }} version=$(jq -r '."."' .github/release-manifest.json) @@ -65,4 +65,4 @@ jobs: git config --local user.email "48985810+david-letterman@users.noreply.github.com" git add version.go git commit -m "chore: Configure the version number" - git push + git push origin ${{ fromJSON(steps.release.outputs.pr).headBranchName }} diff --git a/README.md b/README.md index 293cea5..9990847 100644 --- a/README.md +++ b/README.md @@ -41,19 +41,51 @@ claude := anthropic.NewClient(transport.Client()) Once constructed, you can use the client to interact with the REST API. 
```go -data, _, err := claude.Messages.Create( - context.Background(), - &anthropic.CreateMessageInput{ - MaxTokens: 1024, - Messages: []anthropic.Message{ - { - Content: "Hello, Claude!", - Role: "user", - }, +out, _, err := claude.Messages.Create(ctx, &anthropic.CreateMessageOptions{ + MaxTokens: 1024, + Messages: []anthropic.Message{ + { + Content: "Hello, Claude!", + Role: "user", }, - Model: anthropic.Claude3Opus20240229, }, -) + Model: anthropic.Claude3Opus20240229, +}) +``` + +#### Streaming + +Streaming support is available. + +```go +events, _, err := claude.Messages.Stream(ctx, &anthropic.StreamMessageOptions{ + MaxTokens: 1024, + Messages: []anthropic.Message{ + { + Content: "Hello, Claude!", + Role: "user", + }, + }, + Model: anthropic.Claude3Opus20240229, +}) +``` + +```go +for { + select { + case <-ctx.Done(): + return + case event := <-events: + fmt.Println(event.Data) + } +} +``` + +The Stream method is a wrapper around the Create method that sets the `stream` +parameter to `true`. The Stream method returns a channel of `Event` structs. + +```go + ``` ### Development and testing diff --git a/messages.go b/messages.go index 920d0f4..1d68166 100644 --- a/messages.go +++ b/messages.go @@ -12,13 +12,8 @@ type Message struct { Role string `json:"role"` } -type Content struct { - Type string `json:"type"` - Text string `json:"text"` -} - -// CreateMessageInput defines a structured list of input messages. -type CreateMessageInput struct { +// CreateMessageOptions ... +type CreateMessageOptions struct { // Temperature defines the amount of randomness injected into the response. // Note that even with a temperature of 0.0, results will not be fully // deterministic. @@ -59,14 +54,17 @@ type CreateMessageInput struct { // CreateMessageOutput defines the response from creating a new message. 
type CreateMessageOutput struct { - ID *string `json:"id"` - Type *string `json:"type"` - Role *string `json:"role"` - Model *string `json:"model"` - StopSequence *string `json:"stop_sequence"` - StopReason *string `json:"stop_reason"` - Usage *Usage `json:"usage"` - Content []*Content `json:"content"` + ID *string `json:"id"` + Type *string `json:"type"` + Role *string `json:"role"` + Model *string `json:"model"` + StopSequence *string `json:"stop_sequence"` + StopReason *string `json:"stop_reason"` + Usage *Usage `json:"usage"` + Content []*struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` } // String implements the fmt.Stringer interface for CreateMessageOutput. @@ -77,18 +75,18 @@ func (c *CreateMessageOutput) String() string { // Create creates a new message using the provided options. func (c *MessagesService) Create( ctx context.Context, - in *CreateMessageInput, + opts *CreateMessageOptions, ) (*CreateMessageOutput, *http.Response, error) { - req, err := c.client.NewRequest(http.MethodPost, "messages", in) + req, err := c.client.NewRequest(http.MethodPost, "messages", opts) if err != nil { return nil, nil, err } - out := new(CreateMessageOutput) - resp, err := c.client.Do(ctx, req, out) + output := new(CreateMessageOutput) + resp, err := c.client.Do(ctx, req, output) if err != nil { return nil, resp, err } - return out, resp, nil + return output, resp, nil } diff --git a/messages_test.go b/messages_test.go index cdf3a1b..81f9746 100644 --- a/messages_test.go +++ b/messages_test.go @@ -1 +1,12 @@ package anthropic + +import ( + "testing" +) + +func TestMessagesService_Create(t *testing.T) { + _ = &CreateMessageOptions{ + MaxTokens: 1024, + } + t.SkipNow() +} diff --git a/streaming.go b/streaming.go new file mode 100644 index 0000000..4991a3b --- /dev/null +++ b/streaming.go @@ -0,0 +1,123 @@ +package anthropic + +import ( + "bufio" + "context" + "encoding/json" + "io" + "net/http" +) + +// StreamMessageOptions defines the 
options available when streaming server-sent +// events from the Anthropic REST API, the StreamMessageOptions definition is +// currently identical to the CreateMessageOptions definition, but the +// StreamMessageOptions type has a custom MarshalJSON implementation that will +// append a stream field and set it to true. +type StreamMessageOptions struct { + // Temperature defines the amount of randomness injected into the response. + // Note that even with a temperature of 0.0, results will not be fully + // deterministic. + Temperature *float64 `json:"temperature,omitempty"` + // TopK is used to remove long tail low probability responses by only + // sampling from the top K options for each subsequent token. + // Recommended for advanced use cases only. You usually only need to use + // Temperature. + TopK *int `json:"top_k,omitempty"` + // TopP is the nucleus-sampling parameter. Temperature or TopP should be + // used, but not both. + // Recommended for advanced use cases only. You usually only need to use + // Temperature. + TopP *float64 `json:"top_p,omitempty"` + // Model defines the language model that will be used to complete the + // prompt. See model.go for a list of available models. + Model LanguageModel `json:"model"` + // System provides a means of specifying context and instructions to the + // model, such as specifying a particular goal or role. + System string `json:"system,omitempty"` + // Messages are the input messages, models are trained to operate on + // alternating user and assistant conversational turns. When creating a new + // message, prior conversational turns can be specified with this field, + // and the model generates the next Message in the conversation. + Messages []Message `json:"messages"` + // StopSequences defines custom text sequences that will cause the model to + // stop generating. 
If the model encounters any of the sequences, the + // StopReason field will be set to "stop_sequence" and the response + // StopSequence field will be set to the sequence that caused the model to + // stop. + StopSequences []string `json:"stop_sequences,omitempty"` + // MaxTokens defines the maximum number of tokens to generate before + // stopping. Token generation may stop before reaching this limit, this only + // specifies the absolute maximum number of tokens to generate. Different + // models have different maximum token limits. + MaxTokens int `json:"max_tokens"` +} + +// MarshalJSON implements the json.Marshaler interface for StreamMessageOptions. +// When StreamMessageOptions is marshalled to JSON, a stream field will be added +// and set to a boolean value of true. +func (c *StreamMessageOptions) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Temperature *float64 `json:"temperature,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Model LanguageModel `json:"model"` + System string `json:"system,omitempty"` + Messages []Message `json:"messages"` + StopSequences []string `json:"stop_sequences,omitempty"` + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` + }{ + Temperature: c.Temperature, + TopK: c.TopK, + TopP: c.TopP, + Model: c.Model, + System: c.System, + Messages: c.Messages, + StopSequences: c.StopSequences, + MaxTokens: c.MaxTokens, + Stream: true, + }) +} + +// ServerSentEvent defines a server-sent event. +type ServerSentEvent struct { + Event *string + Data string + Raw []string +} + +// Stream creates a new message using the provided options and streams the +// response using server-sent events. This is a convenience method that +// combines the Create and Stream methods. 
+func (c *MessagesService) Stream(
+	ctx context.Context,
+	opts *StreamMessageOptions,
+) (<-chan ServerSentEvent, *http.Response, error) {
+	req, err := c.client.NewRequest(http.MethodPost, "messages", opts)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	resp, err := c.client.Do(ctx, req, nil)
+	if err != nil {
+		return nil, resp, err
+	}
+	//goland:noinspection GoUnhandledErrorResult
+	defer resp.Body.Close()
+
+	output, err := newServerSentEventStream(resp.Body)
+
+	return output, resp, err
+}
+
+func newServerSentEventStream(body io.ReadCloser) (<-chan ServerSentEvent, error) {
+	scanner := bufio.NewScanner(body)
+	scanner.Buffer(make([]byte, 4096), bufio.MaxScanTokenSize)
+	scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
+		return 0, nil, nil
+	})
+
+	// TODO
+
+	return make(chan ServerSentEvent), nil
+}
diff --git a/streaming_test.go b/streaming_test.go
new file mode 100644
index 0000000..f84ac04
--- /dev/null
+++ b/streaming_test.go
@@ -0,0 +1,48 @@
+package anthropic
+
+import (
+	"bytes"
+	"encoding/json"
+	"testing"
+)
+
+func TestStreamMessageOptions_MarshalJSON_Empty(t *testing.T) {
+	options, err := json.Marshal(&StreamMessageOptions{})
+	if err != nil {
+		t.Error(err)
+	}
+
+	expected := []byte(`{"model":"","messages":null,"max_tokens":0,"stream":true}`)
+	if !bytes.Equal(options, expected) {
+		t.Errorf("options = %q, want %q", options, expected)
+	}
+}
+
+func TestMessagesService_Stream(t *testing.T) {
+	bytes.NewBufferString(`
+event: message_start
+data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+event: content_block_start
+data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}
+
+event: ping
+data: {"type": "ping"}
+
+event: content_block_delta
+data: {"type": 
"content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}} + +event: message_stop +data: {"type": "message_stop"} +`) + t.SkipNow() +}