Skip to content

Commit

Permalink
example of cloudflare ai
Browse files Browse the repository at this point in the history
  • Loading branch information
mojocn authored Mar 18, 2024
1 parent a9c1f39 commit 09e61c3
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
Expand Down
17 changes: 6 additions & 11 deletions cloudflare_ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ type CfTextGenerationResponse struct {
}

type CfTextGenerationMsg struct {
Role string `json:"role"`
Role string `json:"role"`
Content string `json:"content"`
}


type CfTextGenerationArg struct{
Stream bool `json:"stream,omitempty"`
type CfTextGenerationArg struct {
Stream bool `json:"stream,omitempty"`
Messages []CfTextGenerationMsg `json:"messages,omitempty"`
}

Expand All @@ -31,17 +30,13 @@ func (c *CfTextGenerationArg) body() (io.ReadCloser, error) {
return io.NopCloser(buff), err
}


type CloudflareAI struct {
AccountID string
APIToken string

}

var httpClient = &http.Client{}



var modelsTextGeneration = []string{
//https://dash.cloudflare.com/0a76b889e644c012524110042e6f197e/ai/workers-ai
//page 1
Expand Down Expand Up @@ -81,11 +76,11 @@ func (c *CloudflareAI) modelCheck(model string) error {
return nil
}
}
return errors.New("model not found: "+model)
return errors.New("model not found: " + model)
}

func (c *CloudflareAI) Do(model string, arg *CfTextGenerationArg) (*http.Response, error) {
if c.AccountID == "" || c.APIToken == "" {
if c.AccountID == "" || c.APIToken == "" {
return nil, errors.New("CF_ACCOUNT_ID and CF_API_TOKEN environment variables are required")
}

Expand All @@ -104,4 +99,4 @@ func (c *CloudflareAI) Do(model string, arg *CfTextGenerationArg) (*http.Respons
}
req.Header.Set("Authorization", "Bearer "+c.APIToken)
return httpClient.Do(req)
}
}
2 changes: 1 addition & 1 deletion cloudflare_ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestReadFromCloudflareLama2(t *testing.T) {
}

// Send the POST request
response, err := cf.Do("@cf/meta/llama-2-7b-chat-fp8b", &CfTextGenerationArg{
response, err := cf.Do("@cf/meta/llama-2-7b-chat-int8", &CfTextGenerationArg{
Stream: true,
Messages: []CfTextGenerationMsg{
{Role: "system", Content: "You are a chatbot."},
Expand Down
74 changes: 73 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,84 @@
package sseread_test

import (
"encoding/json"
"fmt"
"github.com/mojocn/sseread"
"io"
"log"
"net/http"
"os"
"strings"

"github.com/mojocn/sseread"
)

func ExampleDo() {
// Retrieve the account ID and API token from the environment variables
accountID := os.Getenv("CF_ACCOUNT_ID")
apiToken := os.Getenv("CF_API_TOKEN")

cf := &sseread.CloudflareAI{
AccountID: accountID,
APIToken: apiToken,
}

// Send the POST request
response, err := cf.Do("@cf/meta/llama-2-7b-chat-fp8b", &sseread.CfTextGenerationArg{
Stream: true,
Messages: []sseread.CfTextGenerationMsg{
{Role: "system", Content: "You are a chatbot."},
{Role: "user", Content: "What is your name?"},
}})
if err != nil {
fmt.Println(err)
return
}

// Ensure the response body is closed after the function returns
defer response.Body.Close()

// Check the response status code
if response.StatusCode != http.StatusOK {
all, err := io.ReadAll(response.Body)
if err != nil {
fmt.Println(err)
return
}
log.Fatal(string(all))
return
}

// Read the response body as Server-Sent Events
channel, err := sseread.ReadCh(response.Body)
if err != nil {
fmt.Println(err)
return
}

// Initialize an empty string to store the full text of the responses
fulltext := ""

// Iterate over the events from the channel
for event := range channel {
if event == nil || event.IsSkip() {
continue
}

// Parse the JSON object from the event data
e := new(sseread.CfTextGenerationResponse)
err := json.Unmarshal(event.Data, e)
if err != nil {
log.Fatal(err, string(event.Data))
} else {
// Append the response to the fulltext string
fulltext += e.Response
}
}

// Log the full text of the responses
fmt.Println(fulltext)
}

// ExampleRead is a function that demonstrates how to read Server-Sent Events (SSE) from a specific URL.
func ExampleRead() {
// Send a GET request to the specified URL.
Expand Down

0 comments on commit 09e61c3

Please sign in to comment.