feat: Add support for streaming server-sent events
Showing 6 changed files with 246 additions and 34 deletions.
@@ -1 +1,12 @@
package anthropic

import (
	"testing"
)

func TestMessagesService_Create(t *testing.T) {
	_ = &CreateMessageOptions{
		MaxTokens: 1024,
	}
	t.SkipNow()
}
@@ -0,0 +1,123 @@
package anthropic

import (
	"bufio"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"
)

// StreamMessageOptions defines the options available when streaming
// server-sent events from the Anthropic REST API. The definition is currently
// identical to CreateMessageOptions, but StreamMessageOptions has a custom
// MarshalJSON implementation that appends a stream field set to true.
type StreamMessageOptions struct {
	// Temperature defines the amount of randomness injected into the
	// response. Note that even with a temperature of 0.0, results will not
	// be fully deterministic.
	Temperature *float64 `json:"temperature,omitempty"`
	// TopK is used to remove long tail low probability responses by only
	// sampling from the top K options for each subsequent token.
	// Recommended for advanced use cases only. You usually only need to use
	// Temperature.
	TopK *int `json:"top_k,omitempty"`
	// TopP is the nucleus-sampling parameter. Temperature or TopP should be
	// used, but not both.
	// Recommended for advanced use cases only. You usually only need to use
	// Temperature.
	TopP *float64 `json:"top_p,omitempty"`
	// Model defines the language model that will be used to complete the
	// prompt. See model.go for a list of available models.
	Model LanguageModel `json:"model"`
	// System provides a means of specifying context and instructions to the
	// model, such as specifying a particular goal or role.
	System string `json:"system,omitempty"`
	// Messages are the input messages; models are trained to operate on
	// alternating user and assistant conversational turns. When creating a
	// new message, prior conversational turns can be specified with this
	// field, and the model generates the next Message in the conversation.
	Messages []Message `json:"messages"`
	// StopSequences defines custom text sequences that will cause the model
	// to stop generating. If the model encounters any of the sequences, the
	// StopReason field will be set to "stop_sequence" and the response
	// StopSequence field will be set to the sequence that caused the model
	// to stop.
	StopSequences []string `json:"stop_sequences,omitempty"`
	// MaxTokens defines the maximum number of tokens to generate before
	// stopping. Token generation may stop before reaching this limit; this
	// only specifies the absolute maximum number of tokens to generate.
	// Different models have different maximum token limits.
	MaxTokens int `json:"max_tokens"`
}

// MarshalJSON implements the json.Marshaler interface for StreamMessageOptions.
// When StreamMessageOptions is marshalled to JSON, a stream field is added and
// set to a boolean value of true.
func (c *StreamMessageOptions) MarshalJSON() ([]byte, error) {
	return json.Marshal(struct {
		Temperature   *float64      `json:"temperature,omitempty"`
		TopK          *int          `json:"top_k,omitempty"`
		TopP          *float64      `json:"top_p,omitempty"`
		Model         LanguageModel `json:"model"`
		System        string        `json:"system,omitempty"`
		Messages      []Message     `json:"messages"`
		StopSequences []string      `json:"stop_sequences,omitempty"`
		MaxTokens     int           `json:"max_tokens"`
		Stream        bool          `json:"stream"`
	}{
		Temperature:   c.Temperature,
		TopK:          c.TopK,
		TopP:          c.TopP,
		Model:         c.Model,
		System:        c.System,
		Messages:      c.Messages,
		StopSequences: c.StopSequences,
		MaxTokens:     c.MaxTokens,
		Stream:        true,
	})
}

// ServerSentEvent defines a server-sent event.
type ServerSentEvent struct {
	// Event is the event name, if an event field was present in the stream.
	Event *string
	// Data is the event payload, the concatenation of the event's data fields.
	Data string
	// Raw contains the unparsed lines that made up the event.
	Raw []string
}

// Stream creates a new message using the provided options and streams the
// response back as server-sent events. The returned channel is closed once
// the final event has been received.
func (c *MessagesService) Stream(
	ctx context.Context,
	opts *StreamMessageOptions,
) (<-chan ServerSentEvent, *http.Response, error) {
	req, err := c.client.NewRequest(http.MethodPost, "messages", opts)
	if err != nil {
		return nil, nil, err
	}

	resp, err := c.client.Do(ctx, req, nil)
	if err != nil {
		return nil, resp, err
	}

	// The response body is intentionally not closed here; the goroutine
	// started by newServerSentEventStream closes it once the stream has
	// been drained.
	output, err := newServerSentEventStream(resp.Body)

	return output, resp, err
}

func newServerSentEventStream(body io.ReadCloser) (<-chan ServerSentEvent, error) {
	scanner := bufio.NewScanner(body)
	scanner.Buffer(make([]byte, 4096), bufio.MaxScanTokenSize)

	output := make(chan ServerSentEvent)
	go func() {
		defer close(output)
		//goland:noinspection GoUnhandledErrorResult
		defer body.Close()
		var event ServerSentEvent
		for scanner.Scan() {
			switch line := scanner.Text(); {
			case line == "" && len(event.Raw) > 0: // a blank line ends the event
				output <- event
				event = ServerSentEvent{}
			case strings.HasPrefix(line, "event:"):
				name := strings.TrimSpace(strings.TrimPrefix(line, "event:"))
				event.Event = &name
				event.Raw = append(event.Raw, line)
			case strings.HasPrefix(line, "data:"):
				event.Data += strings.TrimSpace(strings.TrimPrefix(line, "data:"))
				event.Raw = append(event.Raw, line)
			}
		}
		if len(event.Raw) > 0 { // flush a final event without a trailing blank line
			output <- event
		}
	}()

	return output, nil
}
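
For reference, a minimal caller-side sketch of the new streaming API, not part of this commit: it assumes a string-backed LanguageModel, Message fields named Role and Content, and an already-constructed MessagesService, none of which are shown in this diff.

// streamHello is a hypothetical example; the Message field names and the
// model identifier are assumptions about code outside this diff.
func streamHello(ctx context.Context, messages *MessagesService) error {
	events, _, err := messages.Stream(ctx, &StreamMessageOptions{
		Model:     "claude-3-opus-20240229",
		MaxTokens: 1024,
		Messages:  []Message{{Role: "user", Content: "Hello, Claude"}},
	})
	if err != nil {
		return err
	}
	// The channel is closed by the library once the stream is drained, so a
	// plain range loop terminates on its own.
	for event := range events {
		fmt.Println(event.Data)
	}
	return nil
}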
@@ -0,0 +1,48 @@
package anthropic

import (
	"bytes"
	"encoding/json"
	"testing"
)

func TestStreamMessageOptions_MarshalJSON_Empty(t *testing.T) {
	options, err := json.Marshal(&StreamMessageOptions{})
	if err != nil {
		t.Error(err)
	}

	expected := []byte(`{"model":"","messages":null,"max_tokens":0,"stream":true}`)
	if !bytes.Equal(options, expected) {
		t.Errorf("options = %q, want %q", options, expected)
	}
}

func TestMessagesService_Stream(t *testing.T) {
	bytes.NewBufferString(`
event: message_start
data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}

event: content_block_start
data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}

event: ping
data: {"type": "ping"}

event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}

event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}

event: content_block_stop
data: {"type": "content_block_stop", "index": 0}

event: message_delta
data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}}

event: message_stop
data: {"type": "message_stop"}
`)
	t.SkipNow()
}
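
As an illustration of how a fixture like the one above could eventually drive the parser, a hypothetical test sketch, not part of this commit; it assumes "io" is added to the file's imports so the buffer can be wrapped in io.NopCloser to satisfy the io.ReadCloser parameter.

// TestNewServerSentEventStream_Sketch is a hypothetical example, not part of
// this commit; it feeds two well-formed events through the parser and checks
// the event names that come back.
func TestNewServerSentEventStream_Sketch(t *testing.T) {
	body := io.NopCloser(bytes.NewBufferString(
		"event: ping\ndata: {\"type\": \"ping\"}\n\n" +
			"event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n"))

	events, err := newServerSentEventStream(body)
	if err != nil {
		t.Fatal(err)
	}

	var names []string
	for event := range events {
		if event.Event != nil {
			names = append(names, *event.Event)
		}
	}
	if len(names) != 2 || names[0] != "ping" || names[1] != "message_stop" {
		t.Errorf("event names = %v, want [ping message_stop]", names)
	}
}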