diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index fb9c6d2..362a65e 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -53,9 +53,9 @@ jobs:
           manifest-file: .github/release-manifest.json
           token: ${{ secrets.GITHUB_TOKEN }}
       - name: Update the version number
-        if: ${{ steps.release.outputs.prs_created == 'true' }}
+        if: ${{ steps.release.outputs.prs_created == 'true' && steps.release.outputs.pr != null }}
         run: |
-          git config pull.ff only
+          git config pull.rebase true
          git checkout ${{ fromJSON(steps.release.outputs.pr).headBranchName }}
          git pull origin ${{ fromJSON(steps.release.outputs.pr).headBranchName }}
          version=$(jq -r '."."' .github/release-manifest.json)
@@ -65,4 +65,4 @@ jobs:
           git config --local user.email "48985810+david-letterman@users.noreply.github.com"
           git add version.go
           git commit -m "chore: Configure the version number"
-          git push
+          git push origin ${{ fromJSON(steps.release.outputs.pr).headBranchName }}
diff --git a/README.md b/README.md
index 293cea5..4ba5903 100644
--- a/README.md
+++ b/README.md
@@ -41,19 +41,51 @@ claude := anthropic.NewClient(transport.Client())
 
 Once constructed, you can use the client to interact with the REST API.
 
 ```go
-data, _, err := claude.Messages.Create(
-	context.Background(),
-	&anthropic.CreateMessageInput{
-		MaxTokens: 1024,
-		Messages: []anthropic.Message{
-			{
-				Content: "Hello, Claude!",
-				Role:    "user",
-			},
+data, _, err := claude.Messages.Create(ctx, &anthropic.CreateMessageOptions{
+	MaxTokens: 1024,
+	Messages: []anthropic.Message{
+		{
+			Content: "Hello, Claude!",
+			Role:    "user",
 		},
-		Model: anthropic.Claude3Opus20240229,
 	},
-)
+	Model: anthropic.Claude3Opus20240229,
+})
+```
+
+#### Streaming
+
+Streaming support is available through the Stream method.
+
+```go
+events, _, err := claude.Messages.Stream(ctx, &anthropic.StreamMessageOptions{
+	MaxTokens: 1024,
+	Messages: []anthropic.Message{
+		{
+			Content: "Hello, Claude!",
+			Role:    "user",
+		},
+	},
+	Model: anthropic.Claude3Opus20240229,
+})
+```
+
+```go
+for {
+	select {
+	case <-ctx.Done():
+		return
+	case event, ok := <-events:
+		if !ok {
+			return
+		}
+		fmt.Println(event.Data)
+	}
+}
+```
+
+The Stream method sends the same request as the Create method, but with the
+`stream` parameter set to `true`, and returns a receive-only channel of
+`ServerSentEvent` values. The channel is closed when the stream ends.
+
+```go
+type ServerSentEvent struct {
+	Event *string
+	Data  string
+	Raw   []string
+}
 ```
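+
+Each `ServerSentEvent` carries the raw payload of a single event in its
+`Data` field. As an illustrative sketch (the payload shapes are described in
+Anthropic's streaming API documentation), the text deltas can be decoded with
+the standard library:
+
+```go
+var delta struct {
+	Delta struct {
+		Text string `json:"text"`
+	} `json:"delta"`
+}
+if err := json.Unmarshal([]byte(event.Data), &delta); err == nil {
+	fmt.Print(delta.Delta.Text)
+}
+```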
 
 ### Development and testing
diff --git a/anthropic.go b/anthropic.go
index 5203dd2..09e6895 100644
--- a/anthropic.go
+++ b/anthropic.go
@@ -15,7 +15,7 @@ import (
 const (
 	defaultAPIVersion    = "2023-06-01"
 	defaultBaseURL       = "https://api.anthropic.com/v1/"
-	defaultBaseUserAgent = "anthropic-sdk-go"
+	defaultBaseUserAgent = "unfunco/anthropic-sdk-go"
 )
 
 // Client manages communication with the Anthropic REST API.
@@ -77,7 +77,7 @@ func (c *Client) NewRequest(method, path string, body any) (*http.Request, error
 	}
 
 	req.Header.Set("anthropic-version", defaultAPIVersion)
-	req.Header.Set("user-agent", defaultBaseUserAgent+"/"+semanticVersion)
+	req.Header.Set("user-agent", defaultBaseUserAgent+"@"+semanticVersion)
 
 	if body != nil {
 		req.Header.Set("content-type", "application/json")
diff --git a/messages.go b/messages.go
index 920d0f4..ee2919f 100644
--- a/messages.go
+++ b/messages.go
@@ -2,6 +2,7 @@ package anthropic
 
 import (
 	"context"
+	"encoding/json"
 	"net/http"
 )
 
@@ -17,8 +18,8 @@ type Content struct {
 	Text string `json:"text"`
 }
 
-// CreateMessageInput defines a structured list of input messages.
-type CreateMessageInput struct {
+// CreateMessageOptions defines the options available when creating a message.
+type CreateMessageOptions struct {
 	// Temperature defines the amount of randomness injected into the response.
 	// Note that even with a temperature of 0.0, results will not be fully
 	// deterministic.
@@ -57,6 +58,31 @@
 	MaxTokens int `json:"max_tokens"`
 }
 
+// MarshalJSON implements the json.Marshaler interface for CreateMessageOptions.
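+//
+// The marshaled JSON always includes "stream": false. For example, a minimal
+// options value produces (a sketch, assuming Message marshals to its content
+// and role fields):
+//
+//	{"model":"claude-3-opus-20240229","messages":[{"content":"Hello, Claude!","role":"user"}],"max_tokens":1024,"stream":false}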
+func (c *CreateMessageOptions) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Temperature   *float64      `json:"temperature,omitempty"`
+		TopK          *int          `json:"top_k,omitempty"`
+		TopP          *float64      `json:"top_p,omitempty"`
+		Model         LanguageModel `json:"model"`
+		System        string        `json:"system,omitempty"`
+		Messages      []Message     `json:"messages"`
+		StopSequences []string      `json:"stop_sequences,omitempty"`
+		MaxTokens     int           `json:"max_tokens"`
+		Stream        bool          `json:"stream"`
+	}{
+		Temperature:   c.Temperature,
+		TopK:          c.TopK,
+		TopP:          c.TopP,
+		Model:         c.Model,
+		System:        c.System,
+		Messages:      c.Messages,
+		StopSequences: c.StopSequences,
+		MaxTokens:     c.MaxTokens,
+		Stream:        false,
+	})
+}
+
 // CreateMessageOutput defines the response from creating a new message.
 type CreateMessageOutput struct {
 	ID *string `json:"id"`
@@ -77,18 +103,18 @@ func (c *CreateMessageOutput) String() string {
 // Create creates a new message using the provided options.
 func (c *MessagesService) Create(
 	ctx context.Context,
-	in *CreateMessageInput,
+	input *CreateMessageOptions,
 ) (*CreateMessageOutput, *http.Response, error) {
-	req, err := c.client.NewRequest(http.MethodPost, "messages", in)
+	req, err := c.client.NewRequest(http.MethodPost, "messages", input)
 	if err != nil {
 		return nil, nil, err
 	}
 
-	out := new(CreateMessageOutput)
-	resp, err := c.client.Do(ctx, req, out)
+	output := new(CreateMessageOutput)
+	resp, err := c.client.Do(ctx, req, output)
 	if err != nil {
 		return nil, resp, err
 	}
 
-	return out, resp, nil
+	return output, resp, nil
 }
diff --git a/messages_test.go b/messages_test.go
index cdf3a1b..01621df 100644
--- a/messages_test.go
+++ b/messages_test.go
@@ -1 +1,42 @@
 package anthropic
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestMessagesService_Create(t *testing.T) {
+	_ = &CreateMessageOptions{
+		MaxTokens: 1024,
+	}
+	t.SkipNow()
+}
+
+func TestMessagesService_Stream(t *testing.T) {
+	bytes.NewBufferString(`
+event: message_start
+data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+event: content_block_start
+data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}
+
+event: ping
+data: {"type": "ping"}
+
+event: content_block_delta
+data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
+
+event: content_block_delta
+data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
+
+event: content_block_stop
+data: {"type": "content_block_stop", "index": 0}
+
+event: message_delta
+data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}}
+
+event: message_stop
+data: {"type": "message_stop"}
+`)
+	t.SkipNow()
+}
diff --git a/streaming.go b/streaming.go
new file mode 100644
index 0000000..dd06e1c
--- /dev/null
+++ b/streaming.go
@@ -0,0 +1,117 @@
+package anthropic
+
+import (
+	"bufio"
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
+	"strings"
+)
+
+// StreamMessageOptions defines the options available when streaming a
+// message.
+type StreamMessageOptions struct {
+	// Temperature defines the amount of randomness injected into the response.
+	// Note that even with a temperature of 0.0, results will not be fully
+	// deterministic.
+	Temperature *float64 `json:"temperature,omitempty"`
+	// TopK is used to remove long tail low probability responses by only
+	// sampling from the top K options for each subsequent token.
+	// Recommended for advanced use cases only. You usually only need to use
+	// Temperature.
+	TopK *int `json:"top_k,omitempty"`
+	// TopP is the nucleus-sampling parameter. Temperature or TopP should be
+	// used, but not both.
+	// Recommended for advanced use cases only. You usually only need to use
+	// Temperature.
+	TopP *float64 `json:"top_p,omitempty"`
+	// Model defines the language model that will be used to complete the
+	// prompt. See model.go for a list of available models.
+	Model LanguageModel `json:"model"`
+	// System provides a means of specifying context and instructions to the
+	// model, such as specifying a particular goal or role.
+	System string `json:"system,omitempty"`
+	// Messages are the input messages, models are trained to operate on
+	// alternating user and assistant conversational turns. When creating a new
+	// message, prior conversational turns can be specified with this field,
+	// and the model generates the next Message in the conversation.
+	Messages []Message `json:"messages"`
+	// StopSequences defines custom text sequences that will cause the model to
+	// stop generating. If the model encounters any of the sequences, the
+	// StopReason field will be set to "stop_sequence" and the response
+	// StopSequence field will be set to the sequence that caused the model to
+	// stop.
+	StopSequences []string `json:"stop_sequences,omitempty"`
+	// MaxTokens defines the maximum number of tokens to generate before
+	// stopping. Token generation may stop before reaching this limit, this
+	// only specifies the absolute maximum number of tokens to generate.
+	// Different models have different maximum token limits.
+	MaxTokens int `json:"max_tokens"`
+}
+
+// MarshalJSON implements the json.Marshaler interface for StreamMessageOptions.
+func (c *StreamMessageOptions) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Temperature   *float64      `json:"temperature,omitempty"`
+		TopK          *int          `json:"top_k,omitempty"`
+		TopP          *float64      `json:"top_p,omitempty"`
+		Model         LanguageModel `json:"model"`
+		System        string        `json:"system,omitempty"`
+		Messages      []Message     `json:"messages"`
+		StopSequences []string      `json:"stop_sequences,omitempty"`
+		MaxTokens     int           `json:"max_tokens"`
+		Stream        bool          `json:"stream"`
+	}{
+		Temperature:   c.Temperature,
+		TopK:          c.TopK,
+		TopP:          c.TopP,
+		Model:         c.Model,
+		System:        c.System,
+		Messages:      c.Messages,
+		StopSequences: c.StopSequences,
+		MaxTokens:     c.MaxTokens,
+		Stream:        true,
+	})
+}
+
+// ServerSentEvent defines a server-sent event.
+type ServerSentEvent struct {
+	Event *string
+	Data  string
+	Raw   []string
+}
+
+// Stream creates a new message using the provided options and streams the
+// response as a sequence of server-sent events. It sends the same request as
+// the Create method, but with the stream parameter set to true, and returns
+// a receive-only channel of events that is closed when the stream ends.
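+//
+// A sketch of typical usage (assuming a configured Client named claude):
+//
+//	events, _, err := claude.Messages.Stream(ctx, &StreamMessageOptions{
+//		MaxTokens: 1024,
+//		Messages:  []Message{{Content: "Hello, Claude!", Role: "user"}},
+//		Model:     Claude3Opus20240229,
+//	})
+//	if err != nil {
+//		return err
+//	}
+//	for event := range events {
+//		fmt.Println(event.Data)
+//	}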
+func (c *MessagesService) Stream(
+	ctx context.Context,
+	input *StreamMessageOptions,
+) (<-chan ServerSentEvent, *http.Response, error) {
+	req, err := c.client.NewRequest(http.MethodPost, "messages", input)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	resp, err := c.client.Do(ctx, req, nil)
+	if err != nil {
+		return nil, resp, err
+	}
+
+	// Ownership of the response body passes to the event stream, which
+	// closes it once the stream has been drained, so it must not be closed
+	// here.
+	return newServerSentEventStream(resp.Body), resp, nil
+}
+
+// newServerSentEventStream parses server-sent events from body and delivers
+// them on the returned channel. The channel is closed, and the body released,
+// when the stream ends.
+func newServerSentEventStream(body io.ReadCloser) <-chan ServerSentEvent {
+	events := make(chan ServerSentEvent)
+
+	go func() {
+		defer close(events)
+		//goland:noinspection GoUnhandledErrorResult
+		defer body.Close()
+
+		scanner := bufio.NewScanner(body)
+		scanner.Buffer(make([]byte, 4096), bufio.MaxScanTokenSize)
+
+		var event ServerSentEvent
+		for scanner.Scan() {
+			line := scanner.Text()
+			// A blank line terminates the current event.
+			if line == "" {
+				if len(event.Raw) > 0 {
+					events <- event
+					event = ServerSentEvent{}
+				}
+				continue
+			}
+			event.Raw = append(event.Raw, line)
+			if name, ok := strings.CutPrefix(line, "event: "); ok {
+				event.Event = &name
+			} else if data, ok := strings.CutPrefix(line, "data: "); ok {
+				event.Data = data
+			}
+		}
+		if len(event.Raw) > 0 {
+			events <- event
+		}
+	}()
+
+	return events
+}
diff --git a/streaming_test.go b/streaming_test.go
new file mode 100644
index 0000000..cdf3a1b
--- /dev/null
+++ b/streaming_test.go
@@ -0,0 +1 @@
+package anthropic
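+
+import (
+	"io"
+	"strings"
+	"testing"
+)
+
+// TestNewServerSentEventStream is an illustrative sketch rather than part of
+// the original patch; it feeds a canned stream through the parser and checks
+// the first event. The fixture shape follows the events in messages_test.go.
+func TestNewServerSentEventStream(t *testing.T) {
+	body := io.NopCloser(strings.NewReader("event: ping\ndata: {\"type\": \"ping\"}\n\n"))
+	events := newServerSentEventStream(body)
+
+	event, ok := <-events
+	if !ok {
+		t.Fatal("expected an event before the stream closed")
+	}
+	if event.Event == nil || *event.Event != "ping" {
+		t.Errorf("unexpected event name: %v", event.Event)
+	}
+	if want := `{"type": "ping"}`; event.Data != want {
+		t.Errorf("event.Data = %q, want %q", event.Data, want)
+	}
+}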