diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index cf9e378..8663418 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,7 +14,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v4 diff --git a/cloudflare_ai.go b/cloudflare_ai.go index b48e297..ba9887a 100644 --- a/cloudflare_ai.go +++ b/cloudflare_ai.go @@ -15,13 +15,12 @@ type CfTextGenerationResponse struct { } type CfTextGenerationMsg struct { - Role string `json:"role"` + Role string `json:"role"` Content string `json:"content"` } - -type CfTextGenerationArg struct{ - Stream bool `json:"stream,omitempty"` +type CfTextGenerationArg struct { + Stream bool `json:"stream,omitempty"` Messages []CfTextGenerationMsg `json:"messages,omitempty"` } @@ -31,17 +30,13 @@ func (c *CfTextGenerationArg) body() (io.ReadCloser, error) { return io.NopCloser(buff), err } - type CloudflareAI struct { AccountID string APIToken string - } var httpClient = &http.Client{} - - var modelsTextGeneration = []string{ //https://dash.cloudflare.com/0a76b889e644c012524110042e6f197e/ai/workers-ai //page 1 @@ -81,11 +76,11 @@ func (c *CloudflareAI) modelCheck(model string) error { return nil } } - return errors.New("model not found: "+model) + return errors.New("model not found: " + model) } func (c *CloudflareAI) Do(model string, arg *CfTextGenerationArg) (*http.Response, error) { - if c.AccountID == "" || c.APIToken == "" { + if c.AccountID == "" || c.APIToken == "" { return nil, errors.New("CF_ACCOUNT_ID and CF_API_TOKEN environment variables are required") } @@ -104,4 +99,4 @@ func (c *CloudflareAI) Do(model string, arg *CfTextGenerationArg) (*http.Respons } req.Header.Set("Authorization", "Bearer "+c.APIToken) return httpClient.Do(req) -} \ No newline at end of file +} diff --git a/cloudflare_ai_test.go b/cloudflare_ai_test.go index 1ab169f..2514185 100644 --- a/cloudflare_ai_test.go +++ b/cloudflare_ai_test.go @@ -23,7 +23,7 @@ func TestReadFromCloudflareLama2(t *testing.T) { } // Send the POST request - response, err := cf.Do("@cf/meta/llama-2-7b-chat-fp8b", &CfTextGenerationArg{ + response, err := cf.Do("@cf/meta/llama-2-7b-chat-int8", &CfTextGenerationArg{ Stream: true, Messages: []CfTextGenerationMsg{ {Role: "system", Content: "You are a chatbot."}, diff --git a/example_test.go b/example_test.go index 2d4710d..2491263 100644 --- a/example_test.go +++ b/example_test.go @@ -1,12 +1,84 @@ package sseread_test import ( + "encoding/json" "fmt" - "github.com/mojocn/sseread" + "io" + "log" "net/http" + "os" "strings" + + "github.com/mojocn/sseread" ) +func ExampleDo() { + // Retrieve the account ID and API token from the environment variables + accountID := os.Getenv("CF_ACCOUNT_ID") + apiToken := os.Getenv("CF_API_TOKEN") + + cf := &sseread.CloudflareAI{ + AccountID: accountID, + APIToken: apiToken, + } + + // Send the POST request + response, err := cf.Do("@cf/meta/llama-2-7b-chat-fp8b", &sseread.CfTextGenerationArg{ + Stream: true, + Messages: []sseread.CfTextGenerationMsg{ + {Role: "system", Content: "You are a chatbot."}, + {Role: "user", Content: "What is your name?"}, + }}) + if err != nil { + fmt.Println(err) + return + } + + // Ensure the response body is closed after the function returns + defer response.Body.Close() + + // Check the response status code + if response.StatusCode != http.StatusOK { + all, err := io.ReadAll(response.Body) + if err != nil { + fmt.Println(err) + return + } + log.Fatal(string(all)) + return + } + + // Read the response body as Server-Sent Events + channel, err := sseread.ReadCh(response.Body) + if err != nil { + fmt.Println(err) + return + } + + // Initialize an empty string to store the full text of the responses + fulltext := "" + + // Iterate over the events from the channel + for event := range channel { + if event == nil || event.IsSkip() { + continue + } + + // Parse the JSON object from the event data + e := new(sseread.CfTextGenerationResponse) + err := json.Unmarshal(event.Data, e) + if err != nil { + log.Fatal(err, string(event.Data)) + } else { + // Append the response to the fulltext string + fulltext += e.Response + } + } + + // Log the full text of the responses + fmt.Println(fulltext) +} + // ExampleRead is a function that demonstrates how to read Server-Sent Events (SSE) from a specific URL. func ExampleRead() { // Send a GET request to the specified URL.