@@ -15,30 +15,45 @@ const (
15
15
GenerateTextEndpoint string = GenerationEndpoint + "/text"
16
16
)
17
17
18
- type GenerateResult struct {
19
- GeneratedText string `json:"generated_text"`
20
- StopReason string `json:"stop_reason"`
18
// StopReason names the condition that ended a text generation; it is the type
// of GenerateTextResult.StopReason. It is declared as an alias of string, so
// StopReason values and plain strings are interchangeable without conversion.
type StopReason = string

// StopReason values declared by this package.
const (
	// NotFinished indicates that more tokens may still be streamed.
	NotFinished StopReason = "NOT_FINISHED"
	// MaxTokens indicates that the maximum requested number of tokens was reached.
	MaxTokens StopReason = "MAX_TOKENS"
	// EndOfSequenceToken indicates that an end-of-sequence token was encountered.
	EndOfSequenceToken StopReason = "EOS_TOKEN"
	// Cancelled indicates that the request was canceled by the client.
	Cancelled StopReason = "CANCELLED"
	// TimeLimit indicates that the time limit was reached.
	TimeLimit StopReason = "TIME_LIMIT"
	// StopSequence indicates that a stop sequence was encountered.
	StopSequence StopReason = "STOP_SEQUENCE"
	// TokenLimit indicates that the token limit was reached.
	TokenLimit StopReason = "TOKEN_LIMIT"
	// Error indicates that an error was encountered.
	Error StopReason = "ERROR"
)
30
+
31
+ type GenerateTextResult struct {
32
+ Text string `json:"generated_text"`
33
+ GeneratedTokenCount int `json:"generated_token_count"`
34
+ InputTokenCount int `json:"input_token_count"`
35
+ StopReason StopReason `json:"stop_reason"`
21
36
}
22
37
23
- type GeneratePayload struct {
38
+ type GenerateTextPayload struct {
24
39
ProjectID string `json:"project_id"`
25
40
Model string `json:"model_id"`
26
41
Prompt string `json:"input"`
27
42
Parameters * GenerateOptions `json:"parameters,omitempty"`
28
43
}
29
44
30
- type generateResponse struct {
31
- Status string `json:"status"`
32
- StatusCode int `json:"status_code"`
33
- Results []GenerateResult `json:"results"`
45
+ type generateTextResponse struct {
46
+ Status string `json:"status"`
47
+ StatusCode int `json:"status_code"`
48
+ Results []GenerateTextResult `json:"results"`
34
49
}
35
50
36
51
// GenerateText generates completion text based on a given prompt and parameters
37
- func (m * Model ) GenerateText (prompt string , options ... GenerateOption ) (string , error ) {
52
+ func (m * Model ) GenerateText (prompt string , options ... GenerateOption ) (GenerateTextResult , error ) {
38
53
m .CheckAndRefreshToken ()
39
54
40
55
if prompt == "" {
41
- return "" , errors .New ("prompt cannot be empty" )
56
+ return GenerateTextResult {} , errors .New ("prompt cannot be empty" )
42
57
}
43
58
44
59
opts := & GenerateOptions {}
@@ -48,7 +63,7 @@ func (m *Model) GenerateText(prompt string, options ...GenerateOption) (string,
48
63
}
49
64
}
50
65
51
- payload := GeneratePayload {
66
+ payload := GenerateTextPayload {
52
67
ProjectID : m .projectID ,
53
68
Model : m .modelType ,
54
69
Prompt : prompt ,
@@ -57,17 +72,21 @@ func (m *Model) GenerateText(prompt string, options ...GenerateOption) (string,
57
72
58
73
response , err := m .generateTextRequest (payload )
59
74
if err != nil {
60
- return "" , err
75
+ return GenerateTextResult {}, err
76
+ }
77
+
78
+ if len (response .Results ) == 0 {
79
+ return GenerateTextResult {}, errors .New ("no result recieved" )
61
80
}
62
81
63
- result := response .Results [0 ]. GeneratedText
82
+ result := response .Results [0 ]
64
83
65
84
return result , nil
66
85
}
67
86
68
87
// generateTextRequest sends the generate request and handles the response using the http package.
69
88
// Returns error on non-2XX response
70
- func (m * Model ) generateTextRequest (payload GeneratePayload ) (generateResponse , error ) {
89
+ func (m * Model ) generateTextRequest (payload GenerateTextPayload ) (generateTextResponse , error ) {
71
90
params := url.Values {
72
91
"version" : {m .apiVersion },
73
92
}
@@ -81,37 +100,37 @@ func (m *Model) generateTextRequest(payload GeneratePayload) (generateResponse,
81
100
82
101
payloadJSON , err := json .Marshal (payload )
83
102
if err != nil {
84
- return generateResponse {}, err
103
+ return generateTextResponse {}, err
85
104
}
86
105
87
106
req , err := http .NewRequest (http .MethodPost , generateTextURL .String (), bytes .NewBuffer (payloadJSON ))
88
107
if err != nil {
89
- return generateResponse {}, err
108
+ return generateTextResponse {}, err
90
109
}
91
110
92
111
req .Header .Set ("Content-Type" , "application/json" )
93
112
req .Header .Set ("Authorization" , "Bearer " + m .token .value )
94
113
95
114
res , err := m .httpClient .Do (req )
96
115
if err != nil {
97
- return generateResponse {}, err
116
+ return generateTextResponse {}, err
98
117
}
99
118
100
119
statusCode := res .StatusCode
101
120
102
121
if statusCode < 200 || statusCode >= 300 {
103
122
body , err := io .ReadAll (res .Body )
104
123
if err != nil {
105
- return generateResponse {}, fmt .Errorf ("request failed with status code %d" , statusCode )
124
+ return generateTextResponse {}, fmt .Errorf ("request failed with status code %d" , statusCode )
106
125
}
107
- return generateResponse {}, fmt .Errorf ("request failed with status code %d and error %s" , statusCode , body )
126
+ return generateTextResponse {}, fmt .Errorf ("request failed with status code %d and error %s" , statusCode , body )
108
127
}
109
128
defer res .Body .Close ()
110
129
111
- var generateRes generateResponse
130
+ var generateRes generateTextResponse
112
131
113
132
if err := json .NewDecoder (res .Body ).Decode (& generateRes ); err != nil {
114
- return generateResponse {}, err
133
+ return generateTextResponse {}, err
115
134
}
116
135
117
136
return generateRes , nil
0 commit comments