Skip to content

Commit d119ab0

Browse files
committed
Update GenerateText result struct
1 parent ebae778 commit d119ab0

File tree

3 files changed

+43
-24
lines changed

3 files changed

+43
-24
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import (
3636
wx.WithDecodingMethod(wx.Greedy),
3737
)
3838

39-
println(result)
39+
println(result.Text)
4040
```
4141

4242
## Development Setup

models/generate.go

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,45 @@ const (
1515
GenerateTextEndpoint string = GenerationEndpoint + "/text"
1616
)
1717

18-
type GenerateResult struct {
19-
GeneratedText string `json:"generated_text"`
20-
StopReason string `json:"stop_reason"`
18+
type StopReason = string
19+
20+
const (
21+
NotFinished StopReason = "NOT_FINISHED" // Possibly more tokens to be streamed
22+
MaxTokens StopReason = "MAX_TOKENS" // Maximum requested tokens reached
23+
EndOfSequenceToken StopReason = "EOS_TOKEN" // End of sequence token encountered
24+
Cancelled StopReason = "CANCELLED" // Request canceled by the client
25+
TimeLimit StopReason = "TIME_LIMIT" // Time limit reached
26+
StopSequence StopReason = "STOP_SEQUENCE" // Stop sequence encountered
27+
TokenLimit StopReason = "TOKEN_LIMIT" // Token limit reached
28+
Error StopReason = "ERROR" // Error encountered
29+
)
30+
31+
type GenerateTextResult struct {
32+
Text string `json:"generated_text"`
33+
GeneratedTokenCount int `json:"generated_token_count"`
34+
InputTokenCount int `json:"input_token_count"`
35+
StopReason StopReason `json:"stop_reason"`
2136
}
2237

23-
type GeneratePayload struct {
38+
type GenerateTextPayload struct {
2439
ProjectID string `json:"project_id"`
2540
Model string `json:"model_id"`
2641
Prompt string `json:"input"`
2742
Parameters *GenerateOptions `json:"parameters,omitempty"`
2843
}
2944

30-
type generateResponse struct {
31-
Status string `json:"status"`
32-
StatusCode int `json:"status_code"`
33-
Results []GenerateResult `json:"results"`
45+
type generateTextResponse struct {
46+
Status string `json:"status"`
47+
StatusCode int `json:"status_code"`
48+
Results []GenerateTextResult `json:"results"`
3449
}
3550

3651
// GenerateText generates completion text based on a given prompt and parameters
37-
func (m *Model) GenerateText(prompt string, options ...GenerateOption) (string, error) {
52+
func (m *Model) GenerateText(prompt string, options ...GenerateOption) (GenerateTextResult, error) {
3853
m.CheckAndRefreshToken()
3954

4055
if prompt == "" {
41-
return "", errors.New("prompt cannot be empty")
56+
return GenerateTextResult{}, errors.New("prompt cannot be empty")
4257
}
4358

4459
opts := &GenerateOptions{}
@@ -48,7 +63,7 @@ func (m *Model) GenerateText(prompt string, options ...GenerateOption) (string,
4863
}
4964
}
5065

51-
payload := GeneratePayload{
66+
payload := GenerateTextPayload{
5267
ProjectID: m.projectID,
5368
Model: m.modelType,
5469
Prompt: prompt,
@@ -57,17 +72,21 @@ func (m *Model) GenerateText(prompt string, options ...GenerateOption) (string,
5772

5873
response, err := m.generateTextRequest(payload)
5974
if err != nil {
60-
return "", err
75+
return GenerateTextResult{}, err
76+
}
77+
78+
if len(response.Results) == 0 {
79+
return GenerateTextResult{}, errors.New("no result recieved")
6180
}
6281

63-
result := response.Results[0].GeneratedText
82+
result := response.Results[0]
6483

6584
return result, nil
6685
}
6786

6887
// generateTextRequest sends the generate request and handles the response using the http package.
6988
// Returns error on non-2XX response
70-
func (m *Model) generateTextRequest(payload GeneratePayload) (generateResponse, error) {
89+
func (m *Model) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) {
7190
params := url.Values{
7291
"version": {m.apiVersion},
7392
}
@@ -81,37 +100,37 @@ func (m *Model) generateTextRequest(payload GeneratePayload) (generateResponse,
81100

82101
payloadJSON, err := json.Marshal(payload)
83102
if err != nil {
84-
return generateResponse{}, err
103+
return generateTextResponse{}, err
85104
}
86105

87106
req, err := http.NewRequest(http.MethodPost, generateTextURL.String(), bytes.NewBuffer(payloadJSON))
88107
if err != nil {
89-
return generateResponse{}, err
108+
return generateTextResponse{}, err
90109
}
91110

92111
req.Header.Set("Content-Type", "application/json")
93112
req.Header.Set("Authorization", "Bearer "+m.token.value)
94113

95114
res, err := m.httpClient.Do(req)
96115
if err != nil {
97-
return generateResponse{}, err
116+
return generateTextResponse{}, err
98117
}
99118

100119
statusCode := res.StatusCode
101120

102121
if statusCode < 200 || statusCode >= 300 {
103122
body, err := io.ReadAll(res.Body)
104123
if err != nil {
105-
return generateResponse{}, fmt.Errorf("request failed with status code %d", statusCode)
124+
return generateTextResponse{}, fmt.Errorf("request failed with status code %d", statusCode)
106125
}
107-
return generateResponse{}, fmt.Errorf("request failed with status code %d and error %s", statusCode, body)
126+
return generateTextResponse{}, fmt.Errorf("request failed with status code %d and error %s", statusCode, body)
108127
}
109128
defer res.Body.Close()
110129

111-
var generateRes generateResponse
130+
var generateRes generateTextResponse
112131

113132
if err := json.NewDecoder(res.Body).Decode(&generateRes); err != nil {
114-
return generateResponse{}, err
133+
return generateTextResponse{}, err
115134
}
116135

117136
return generateRes, nil

models/test/generate_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func TestGenerateText(t *testing.T) {
7272
if err != nil {
7373
t.Fatalf("Expected no error, but got an error: %v", err)
7474
}
75-
if result == "" {
75+
if result.Text == "" {
7676
t.Fatal("Expected a result, but got an empty string")
7777
}
7878
}
@@ -88,7 +88,7 @@ func TestGenerateTextWithNilOptions(t *testing.T) {
8888
if err != nil {
8989
t.Fatalf("Expected no error, but got an error: %v", err)
9090
}
91-
if result == "" {
91+
if result.Text == "" {
9292
t.Fatal("Expected a result, but got an empty string")
9393
}
9494
}

0 commit comments

Comments
 (0)