
Commit f501936

feat: Add support for streaming server-sent events
1 parent de05838 commit f501936

File tree

6 files changed: +266 -34 lines

.github/workflows/ci.yaml

Lines changed: 3 additions & 3 deletions

@@ -53,9 +53,9 @@ jobs:
           manifest-file: .github/release-manifest.json
           token: ${{ secrets.GITHUB_TOKEN }}
       - name: Update the version number
-        if: ${{ steps.release.outputs.prs_created == 'true' }}
+        if: steps.release.outputs.prs_created && steps.release.outputs.pr != null
         run: |
-          git config pull.ff only
+          git config pull.rebase true
           git checkout ${{ fromJSON(steps.release.outputs.pr).headBranchName }}
           git pull origin ${{ fromJSON(steps.release.outputs.pr).headBranchName }}
           version=$(jq -r '."."' .github/release-manifest.json)
@@ -65,4 +65,4 @@ jobs:
           git config --local user.email "48985810+david-letterman@users.noreply.github.com"
           git add version.go
           git commit -m "chore: Configure the version number"
-          git push
+          git push origin ${{ fromJSON(steps.release.outputs.pr).headBranchName }}

README.md

Lines changed: 43 additions & 11 deletions

@@ -41,19 +41,51 @@ claude := anthropic.NewClient(transport.Client())
 Once constructed, you can use the client to interact with the REST API.
 
 ```go
-data, _, err := claude.Messages.Create(
-    context.Background(),
-    &anthropic.CreateMessageInput{
-        MaxTokens: 1024,
-        Messages: []anthropic.Message{
-            {
-                Content: "Hello, Claude!",
-                Role: "user",
-            },
+out, _, err := claude.Messages.Create(ctx, &anthropic.CreateMessageOptions{
+    MaxTokens: 1024,
+    Messages: []anthropic.Message{
+        {
+            Content: "Hello, Claude!",
+            Role: "user",
         },
-        Model: anthropic.Claude3Opus20240229,
     },
-)
+    Model: anthropic.Claude3Opus20240229,
+})
+```
+
+#### Streaming
+
+Streaming support is available.
+
+```go
+events, _, err := claude.Messages.Stream(ctx, &anthropic.StreamMessageOptions{
+    MaxTokens: 1024,
+    Messages: []anthropic.Message{
+        {
+            Content: "Hello, Claude!",
+            Role: "user",
+        },
+    },
+    Model: anthropic.Claude3Opus20240229,
+})
+```
+
+The returned channel can be drained with a select loop; it is closed once the
+stream ends.
+
+```go
+for {
+    select {
+    case <-ctx.Done():
+        return
+    case event, ok := <-events:
+        if !ok {
+            return
+        }
+        fmt.Println(event.Data)
+    }
+}
+```
+
+The Stream method is a wrapper around the Create method that sets the `stream`
+parameter to `true` and returns a receive-only channel of `ServerSentEvent`
+structs.
-```
 
 ### Development and testing
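Each streamed event carries its payload as raw JSON in the `Data` field. The sketch below shows one way to decode a `content_block_delta` event, mirroring the payload shapes used in the streaming_test.go fixtures; the `contentBlockDelta` type and `printDelta` helper are illustrative, not part of the library:

```go
// contentBlockDelta mirrors the content_block_delta payload shape seen in
// the streaming_test.go fixtures; it is illustrative, not a library type.
type contentBlockDelta struct {
    Type  string `json:"type"`
    Index int    `json:"index"`
    Delta struct {
        Type string `json:"type"`
        Text string `json:"text"`
    } `json:"delta"`
}

// printDelta prints the text carried by a content_block_delta event and
// ignores every other event type.
func printDelta(event anthropic.ServerSentEvent) error {
    if event.Event == nil || *event.Event != "content_block_delta" {
        return nil
    }
    var delta contentBlockDelta
    if err := json.Unmarshal([]byte(event.Data), &delta); err != nil {
        return err
    }
    fmt.Print(delta.Delta.Text)
    return nil
}
```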

messages.go

Lines changed: 18 additions & 20 deletions

@@ -12,13 +12,8 @@ type Message struct {
     Role string `json:"role"`
 }
 
-type Content struct {
-    Type string `json:"type"`
-    Text string `json:"text"`
-}
-
-// CreateMessageInput defines a structured list of input messages.
-type CreateMessageInput struct {
+// CreateMessageOptions defines the options available when creating a new
+// message.
+type CreateMessageOptions struct {
     // Temperature defines the amount of randomness injected into the response.
     // Note that even with a temperature of 0.0, results will not be fully
     // deterministic.
@@ -59,14 +54,17 @@ type CreateMessageInput struct {
 
 // CreateMessageOutput defines the response from creating a new message.
 type CreateMessageOutput struct {
-    ID           *string    `json:"id"`
-    Type         *string    `json:"type"`
-    Role         *string    `json:"role"`
-    Model        *string    `json:"model"`
-    StopSequence *string    `json:"stop_sequence"`
-    StopReason   *string    `json:"stop_reason"`
-    Usage        *Usage     `json:"usage"`
-    Content      []*Content `json:"content"`
+    ID           *string `json:"id"`
+    Type         *string `json:"type"`
+    Role         *string `json:"role"`
+    Model        *string `json:"model"`
+    StopSequence *string `json:"stop_sequence"`
+    StopReason   *string `json:"stop_reason"`
+    Usage        *Usage  `json:"usage"`
+    Content      []*struct {
+        Type string `json:"type"`
+        Text string `json:"text"`
+    } `json:"content"`
 }
 
 // String implements the fmt.Stringer interface for CreateMessageOutput.
@@ -77,18 +75,18 @@ func (c *CreateMessageOutput) String() string {
 // Create creates a new message using the provided options.
 func (c *MessagesService) Create(
     ctx context.Context,
-    in *CreateMessageInput,
+    opts *CreateMessageOptions,
 ) (*CreateMessageOutput, *http.Response, error) {
-    req, err := c.client.NewRequest(http.MethodPost, "messages", in)
+    req, err := c.client.NewRequest(http.MethodPost, "messages", opts)
     if err != nil {
         return nil, nil, err
     }
 
-    out := new(CreateMessageOutput)
-    resp, err := c.client.Do(ctx, req, out)
+    output := new(CreateMessageOutput)
+    resp, err := c.client.Do(ctx, req, output)
     if err != nil {
         return nil, resp, err
     }
 
-    return out, resp, nil
+    return output, resp, nil
 }
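Every field of CreateMessageOutput is a pointer (or slice of pointers), so fields missing from the response can be distinguished from zero values. A minimal usage sketch, assuming the `claude` client and `ctx` from the README example:

```go
out, _, err := claude.Messages.Create(ctx, &anthropic.CreateMessageOptions{
    MaxTokens: 1024,
    Messages: []anthropic.Message{
        {
            Content: "Hello, Claude!",
            Role:    "user",
        },
    },
    Model: anthropic.Claude3Opus20240229,
})
if err != nil {
    log.Fatal(err)
}

// Guard each pointer field before dereferencing it.
if out.StopReason != nil {
    log.Printf("stop reason: %s", *out.StopReason)
}
for _, block := range out.Content {
    if block != nil && block.Type == "text" {
        fmt.Println(block.Text)
    }
}
```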

messages_test.go

Lines changed: 11 additions & 0 deletions

@@ -1 +1,12 @@
 package anthropic
+
+import (
+    "testing"
+)
+
+func TestMessagesService_Create(t *testing.T) {
+    _ = &CreateMessageOptions{
+        MaxTokens: 1024,
+    }
+    t.SkipNow()
+}

streaming.go

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
+package anthropic
+
+import (
+    "bufio"
+    "context"
+    "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+)
+
+// StreamMessageOptions defines the options available when streaming
+// server-sent events from the Anthropic REST API. The definition is currently
+// identical to CreateMessageOptions, but StreamMessageOptions has a custom
+// MarshalJSON implementation that appends a stream field set to true.
+type StreamMessageOptions struct {
+    // Temperature defines the amount of randomness injected into the response.
+    // Note that even with a temperature of 0.0, results will not be fully
+    // deterministic.
+    Temperature *float64 `json:"temperature,omitempty"`
+    // TopK is used to remove long tail low probability responses by only
+    // sampling from the top K options for each subsequent token.
+    // Recommended for advanced use cases only. You usually only need to use
+    // Temperature.
+    TopK *int `json:"top_k,omitempty"`
+    // TopP is the nucleus-sampling parameter. Temperature or TopP should be
+    // used, but not both.
+    // Recommended for advanced use cases only. You usually only need to use
+    // Temperature.
+    TopP *float64 `json:"top_p,omitempty"`
+    // Model defines the language model that will be used to complete the
+    // prompt. See model.go for a list of available models.
+    Model LanguageModel `json:"model"`
+    // System provides a means of specifying context and instructions to the
+    // model, such as specifying a particular goal or role.
+    System string `json:"system,omitempty"`
+    // Messages are the input messages; models are trained to operate on
+    // alternating user and assistant conversational turns. When creating a new
+    // message, prior conversational turns can be specified with this field,
+    // and the model generates the next Message in the conversation.
+    Messages []Message `json:"messages"`
+    // StopSequences defines custom text sequences that will cause the model to
+    // stop generating. If the model encounters any of the sequences, the
+    // StopReason field will be set to "stop_sequence" and the response
+    // StopSequence field will be set to the sequence that caused the model to
+    // stop.
+    StopSequences []string `json:"stop_sequences,omitempty"`
+    // MaxTokens defines the maximum number of tokens to generate before
+    // stopping. Token generation may stop before reaching this limit; this
+    // only specifies the absolute maximum number of tokens to generate.
+    // Different models have different maximum token limits.
+    MaxTokens int `json:"max_tokens"`
+}
+
+// MarshalJSON implements the json.Marshaler interface for StreamMessageOptions.
+// When StreamMessageOptions is marshalled to JSON, a stream field is added and
+// set to a boolean value of true.
+func (c *StreamMessageOptions) MarshalJSON() ([]byte, error) {
+    return json.Marshal(struct {
+        Temperature   *float64      `json:"temperature,omitempty"`
+        TopK          *int          `json:"top_k,omitempty"`
+        TopP          *float64      `json:"top_p,omitempty"`
+        Model         LanguageModel `json:"model,omitempty"`
+        System        string        `json:"system,omitempty"`
+        Messages      []Message     `json:"messages,omitempty"`
+        StopSequences []string      `json:"stop_sequences,omitempty"`
+        MaxTokens     int           `json:"max_tokens,omitempty"`
+        Stream        bool          `json:"stream"`
+    }{
+        Temperature:   c.Temperature,
+        TopK:          c.TopK,
+        TopP:          c.TopP,
+        Model:         c.Model,
+        System:        c.System,
+        Messages:      c.Messages,
+        StopSequences: c.StopSequences,
+        MaxTokens:     c.MaxTokens,
+        Stream:        true,
+    })
+}
+
+// ServerSentEvent defines a server-sent event.
+type ServerSentEvent struct {
+    Event *string
+    Data  string
+    Raw   []string
+}
+
+// Stream creates a new message using the provided options and streams the
+// response using server-sent events. The returned channel is closed once the
+// stream ends.
+func (c *MessagesService) Stream(
+    ctx context.Context,
+    opts *StreamMessageOptions,
+) (<-chan ServerSentEvent, *http.Response, error) {
+    req, err := c.client.NewRequest(http.MethodPost, "messages", opts)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    resp, err := c.client.Do(ctx, req, nil)
+    if err != nil {
+        return nil, resp, err
+    }
+
+    // The response body is closed by newServerSentEventStream once the
+    // stream has been consumed.
+    return newServerSentEventStream(resp.Body), resp, nil
+}
+
+// newServerSentEventStream parses server-sent events from body on a separate
+// goroutine, sending each event on the returned channel. The channel is
+// closed, and body is closed, once the stream ends.
+func newServerSentEventStream(body io.ReadCloser) <-chan ServerSentEvent {
+    events := make(chan ServerSentEvent)
+
+    go func() {
+        defer close(events)
+        //goland:noinspection GoUnhandledErrorResult
+        defer body.Close()
+
+        var event ServerSentEvent
+        scanner := bufio.NewScanner(body)
+        scanner.Buffer(make([]byte, 4096), bufio.MaxScanTokenSize)
+        for scanner.Scan() {
+            line := scanner.Text()
+            // A blank line marks the end of an event.
+            if line == "" {
+                if len(event.Raw) > 0 {
+                    events <- event
+                    event = ServerSentEvent{}
+                }
+                continue
+            }
+            event.Raw = append(event.Raw, line)
+            switch {
+            case strings.HasPrefix(line, "event:"):
+                name := strings.TrimSpace(strings.TrimPrefix(line, "event:"))
+                event.Event = &name
+            case strings.HasPrefix(line, "data:"):
+                event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+            }
+        }
+        // Flush a final event that was not terminated by a blank line.
+        if len(event.Raw) > 0 {
+            events <- event
+        }
+    }()
+
+    return events
+}
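As a quick illustration of the parser, a canned body can be fed through newServerSentEventStream from inside the package. This is a sketch, not part of the commit, and `demoStream` is a hypothetical helper that assumes the implementation above:

```go
package anthropic

import (
    "bytes"
    "fmt"
    "io"
)

// demoStream is a hypothetical helper, not part of this commit, showing how
// newServerSentEventStream behaves on a canned server-sent event body.
func demoStream() {
    body := io.NopCloser(bytes.NewBufferString(
        "event: ping\ndata: {\"type\": \"ping\"}\n\n"))

    // The channel is closed once the body is exhausted, so range terminates.
    for event := range newServerSentEventStream(body) {
        fmt.Printf("%s: %s\n", *event.Event, event.Data)
    }
}
```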

streaming_test.go

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+package anthropic
+
+import (
+    "bytes"
+    "encoding/json"
+    "testing"
+)
+
+func TestStreamMessageOptions_MarshalJSON_Empty(t *testing.T) {
+    options, err := json.Marshal(&StreamMessageOptions{})
+    if err != nil {
+        t.Error(err)
+    }
+
+    expected := []byte(`{"stream":true}`)
+    if !bytes.Equal(options, expected) {
+        t.Errorf("options = %q, want %q", options, expected)
+    }
+}
+
+func TestStreamMessageOptions_MarshalJSON_Initialised(t *testing.T) {
+    options, err := json.Marshal(&StreamMessageOptions{
+        MaxTokens: 512,
+        Messages: []Message{
+            {
+                Content: "This is a test message.",
+                Role:    "user",
+            },
+        },
+    })
+    if err != nil {
+        t.Error(err)
+    }
+
+    expected := []byte(`{"messages":[{"content":"This is a test message.","role":"user"}],"max_tokens":512,"stream":true}`)
+    if !bytes.Equal(options, expected) {
+        t.Errorf("options = %q, want %q", options, expected)
+    }
+}
+
+func TestMessagesService_Stream(t *testing.T) {
+    // Server-sent event fixture for the eventual Stream test; unused until
+    // the test body is implemented.
+    _ = bytes.NewBufferString(`
+event: message_start
+data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+event: content_block_start
+data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}
+
+event: ping
+data: {"type": "ping"}
+
+event: content_block_delta
+data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
+
+event: content_block_delta
+data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
+
+event: content_block_stop
+data: {"type": "content_block_stop", "index": 0}
+
+event: message_delta
+data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}}
+
+event: message_stop
+data: {"type": "message_stop"}
+`)
+    t.SkipNow()
+}
