Skip to content

Commit a9c1f39

Browse files
authored
cloudflare ai worker
1 parent cf442bc commit a9c1f39

File tree

2 files changed

+117
-21
lines changed

2 files changed

+117
-21
lines changed

cloudflare_ai.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package sseread
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"errors"
7+
"io"
8+
"net/http"
9+
)
10+
11+
// https://developers.cloudflare.com/workers-ai/models/zephyr-7b-beta-awq/#using-streaming
12+
type CfTextGenerationResponse struct {
13+
Response string `json:"response"`
14+
P string `json:"p"`
15+
}
16+
17+
type CfTextGenerationMsg struct {
18+
Role string `json:"role"`
19+
Content string `json:"content"`
20+
}
21+
22+
23+
type CfTextGenerationArg struct{
24+
Stream bool `json:"stream,omitempty"`
25+
Messages []CfTextGenerationMsg `json:"messages,omitempty"`
26+
}
27+
28+
func (c *CfTextGenerationArg) body() (io.ReadCloser, error) {
29+
buff := bytes.NewBuffer(nil)
30+
err := json.NewEncoder(buff).Encode(c)
31+
return io.NopCloser(buff), err
32+
}
33+
34+
35+
type CloudflareAI struct {
36+
AccountID string
37+
APIToken string
38+
39+
}
40+
41+
var httpClient = &http.Client{}
42+
43+
44+
45+
var modelsTextGeneration = []string{
46+
//https://dash.cloudflare.com/0a76b889e644c012524110042e6f197e/ai/workers-ai
47+
//page 1
48+
"@cf/meta/llama-2-7b-chat-fp16",
49+
"@cf/mistral/mistral-7b-instruct-v0.1",
50+
"@cf/meta/llama-2-7b-chat-int8",
51+
"@cf/qwen/qwen1.5-0.5b-chat",
52+
"@hf/thebloke/llamaguard-7b-awq",
53+
"@hf/thebloke/neural-chat-7b-v3-1-awq",
54+
"@cf/deepseek-ai/deepseek-math-7b-base",
55+
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
56+
"@hf/thebloke/orca-2-13b-awq",
57+
"@hf/thebloke/codellama-7b-instruct-awq",
58+
//page 2
59+
"@cf/thebloke/discolm-german-7b-v1-awq",
60+
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
61+
"@hf/thebloke/openchat_3.5-awq",
62+
"@cf/qwen/qwen1.5-7b-chat-awq",
63+
"@hf/thebloke/llama-2-13b-chat-awq",
64+
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
65+
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
66+
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
67+
"@cf/deepseek-ai/deepseek-math-7b-instruct",
68+
"@cf/tiiuae/falcon-7b-instruct",
69+
//page 3
70+
"@hf/thebloke/zephyr-7b-beta-awq",
71+
"@cf/qwen/qwen1.5-1.8b-chat",
72+
"@cf/defog/sqlcoder-7b-2",
73+
"@cf/microsoft/phi-2",
74+
"@cf/qwen/qwen1.5-14b-chat-awq",
75+
"@cf/openchat/openchat-3.5-0106",
76+
}
77+
78+
func (c *CloudflareAI) modelCheck(model string) error {
79+
for _, v := range modelsTextGeneration {
80+
if v == model {
81+
return nil
82+
}
83+
}
84+
return errors.New("model not found: "+model)
85+
}
86+
87+
func (c *CloudflareAI) Do(model string, arg *CfTextGenerationArg) (*http.Response, error) {
88+
if c.AccountID == "" || c.APIToken == "" {
89+
return nil, errors.New("CF_ACCOUNT_ID and CF_API_TOKEN environment variables are required")
90+
}
91+
92+
if err := c.modelCheck(model); err != nil {
93+
return nil, err
94+
}
95+
96+
body, err := arg.body()
97+
if err != nil {
98+
return nil, err
99+
}
100+
101+
req, err := http.NewRequest("POST", "https://api.cloudflare.com/client/v4/accounts/"+c.AccountID+"/ai/run/"+model, body)
102+
if err != nil {
103+
return nil, err
104+
}
105+
req.Header.Set("Authorization", "Bearer "+c.APIToken)
106+
return httpClient.Do(req)
107+
}

cloudflare_ai_test.go

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
package sseread
22

33
import (
4-
"bytes"
54
"encoding/json"
65
"io"
76
"net/http"
87
"os"
98
"testing"
109
)
1110

12-
// https://developers.cloudflare.com/workers-ai/models/zephyr-7b-beta-awq/#using-streaming
13-
type llamaMsg struct {
14-
Response string `json:"response"`
15-
P string `json:"p"`
16-
}
17-
1811
// TestReadFromCloudflareLama2 is a test function for the ReadCh function in the sseread package.
1912
// It sends a POST request to the Cloudflare API and reads the response body as Server-Sent Events.
2013
// For each event, it parses the JSON object from the event data and appends the response to the fulltext string.
@@ -23,23 +16,19 @@ func TestReadFromCloudflareLama2(t *testing.T) {
2316
// Retrieve the account ID and API token from the environment variables
2417
accountID := os.Getenv("CF_ACCOUNT_ID")
2518
apiToken := os.Getenv("CF_API_TOKEN")
26-
if accountID == "" || apiToken == "" {
27-
t.Fatal("CF_ACCOUNT_ID and CF_API_TOKEN environment variables are required")
28-
}
29-
// Create a buffer with the request body
30-
buff := bytes.NewBufferString(`{ "stream":true,"messages": [{ "role": "system", "content": "You are a friendly assistant" }, { "role": "user", "content": "Why is pizza so good" }]}`)
3119

32-
// Create a new POST request to the Cloudflare API
33-
req, err := http.NewRequest("POST", "https://api.cloudflare.com/client/v4/accounts/"+accountID+"/ai/run/@cf/meta/llama-2-7b-chat-int8", buff)
34-
if err != nil {
35-
t.Fatal(err)
20+
cf := &CloudflareAI{
21+
AccountID: accountID,
22+
APIToken: apiToken,
3623
}
3724

38-
// Set the Authorization header with the API token
39-
req.Header.Set("Authorization", "Bearer "+apiToken)
40-
4125
// Send the POST request
42-
response, err := http.DefaultClient.Do(req)
26+
response, err := cf.Do("@cf/meta/llama-2-7b-chat-fp8b", &CfTextGenerationArg{
27+
Stream: true,
28+
Messages: []CfTextGenerationMsg{
29+
{Role: "system", Content: "You are a chatbot."},
30+
{Role: "user", Content: "What is your name?"},
31+
}})
4332
if err != nil {
4433
t.Fatal(err)
4534
}
@@ -73,7 +62,7 @@ func TestReadFromCloudflareLama2(t *testing.T) {
7362
}
7463

7564
// Parse the JSON object from the event data
76-
e := new(llamaMsg)
65+
e := new(CfTextGenerationResponse)
7766
err := json.Unmarshal(event.Data, e)
7867
if err != nil {
7968
t.Error(err, string(event.Data))

0 commit comments

Comments
 (0)