Skip to content

Commit 9323358

Browse files
authored
Update to watsonx-go Client (#10)
Signed-off-by: Robby <h0rv@users.noreply.github.com> Co-authored-by: Robby <h0rv@users.noreply.github.com>
1 parent cc529cf commit 9323358

File tree

8 files changed

+151
-107
lines changed

8 files changed

+151
-107
lines changed

README.md

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,56 @@
1-
# go-watsonx
1+
# watsonx-go
22

3-
Zero dependency [watsonx](https://www.ibm.com/watsonx) API Client for Go
3+
`watsonx-go` is a [watsonx](https://www.ibm.com/watsonx) Client for Go
44

55
## Install
66

7-
Install:
8-
97
```sh
10-
go get -u github.com/h0rv/go-watsonx
8+
go get -u github.com/IBM/watsonx-go
119
```
1210

13-
Import:
11+
## Usage
1412

1513
```go
1614
import (
17-
wx "github.com/h0rv/go-watsonx/pkg/models"
15+
wx "github.com/IBM/watsonx-go/pkg/models"
1816
)
1917
```
2018

21-
## Example Usage
19+
### Example Usage
20+
21+
```sh
22+
export WATSONX_API_KEY="YOUR WATSONX API KEY"
23+
export WATSONX_PROJECT_ID="YOUR WATSONX PROJECT ID"
24+
```
25+
26+
Create a client:
2227

2328
```go
24-
model, _ := wx.NewModel(
25-
wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"),
26-
wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"),
27-
)
28-
29-
result, _ := model.GenerateText(
30-
"meta-llama/llama-3-70b-instruct",
31-
"Hi, who are you?",
32-
wx.WithTemperature(0.9),
33-
wx.WithTopP(.5),
34-
wx.WithTopK(10),
35-
wx.WithMaxNewTokens(512),
36-
)
37-
38-
println(result.Text)
29+
client, _ := wx.NewClient()
30+
```
31+
32+
Or pass in the required secrets directly:
33+
34+
```go
35+
client, err := wx.NewClient(
36+
wx.WithWatsonxAPIKey(apiKey),
37+
wx.WithWatsonxProjectID(projectID),
38+
)
39+
```
40+
41+
Generation:
42+
43+
```go
44+
result, _ := client.GenerateText(
45+
"meta-llama/llama-3-70b-instruct",
46+
"Hi, who are you?",
47+
wx.WithTemperature(0.9),
48+
wx.WithTopP(.5),
49+
wx.WithTopK(10),
50+
wx.WithMaxNewTokens(512),
51+
)
52+
53+
println(result.Text)
3954
```
4055

4156
## Development Setup
@@ -45,7 +60,7 @@ import (
4560
#### Setup
4661

4762
```sh
48-
export IBMCLOUD_API_KEY="YOUR IBM CLOUD API KEY"
63+
export WATSONX_API_KEY="YOUR WATSONX API KEY"
4964
export WATSONX_PROJECT_ID="YOUR WATSONX PROJECT ID"
5065
```
5166

@@ -65,5 +80,5 @@ git config --local core.hooksPath .githooks/
6580

6681
## Resources
6782

68-
- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk)
6983
- [watsonx REST API Docs](https://cloud.ibm.com/apidocs/watsonx-ai)
84+
- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk)

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
module github.com/h0rv/go-watsonx
1+
module github.com/IBM/watsonx-go
22

3-
go 1.21.2
3+
go 1.21.4

pkg/internal/tests/models/generate_test.go

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,62 @@ import (
44
"os"
55
"testing"
66

7-
wx "github.com/h0rv/go-watsonx/pkg/models"
7+
wx "github.com/IBM/watsonx-go/pkg/models"
88
)
99

10-
func getModel(t *testing.T) *wx.Model {
11-
apiKey := os.Getenv(wx.WatsonxAPIKeyEnvVarName)
12-
projectID := os.Getenv(wx.WatsonxProjectIDEnvVarName)
10+
func TestClientCreationWithEnvVars(t *testing.T) {
11+
_, err := wx.NewClient()
12+
13+
if err != nil {
14+
t.Fatalf("Expected no error for creating client with environment variables, but got %v", err)
15+
}
16+
}
17+
18+
func TestClientCreationWithPassing(t *testing.T) {
19+
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)
20+
21+
if apiKey == "" {
22+
t.Fatal("No watsonx API key provided")
23+
}
24+
if projectID == "" {
25+
t.Fatal("No watsonx project ID provided")
26+
}
27+
28+
_, err := wx.NewClient(
29+
wx.WithWatsonxAPIKey(apiKey),
30+
wx.WithWatsonxProjectID(projectID),
31+
)
32+
33+
if err != nil {
34+
t.Fatalf("Expected no error for creating client with passing secrets, but got %v", err)
35+
}
36+
}
37+
38+
func getClient(t *testing.T) *wx.Client {
39+
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)
40+
1341
if apiKey == "" {
1442
t.Fatal("No watsonx API key provided")
1543
}
1644
if projectID == "" {
1745
t.Fatal("No watsonx project ID provided")
1846
}
1947

20-
model, err := wx.NewModel(
48+
client, err := wx.NewClient(
2149
wx.WithWatsonxAPIKey(apiKey),
2250
wx.WithWatsonxProjectID(projectID),
2351
)
2452
if err != nil {
25-
t.Fatalf("Failed to create model for testing. Error: %v", err)
53+
t.Fatalf("Failed to create client for testing. Error: %v", err)
2654
}
2755

28-
return model
56+
return client
2957
}
3058

3159
func TestEmptyPromptError(t *testing.T) {
32-
model := getModel(t)
60+
client := getClient(t)
3361

34-
_, err := model.GenerateText(
62+
_, err := client.GenerateText(
3563
"dumby model",
3664
"",
3765
)
@@ -41,9 +69,9 @@ func TestEmptyPromptError(t *testing.T) {
4169
}
4270

4371
func TestNilOptions(t *testing.T) {
44-
model := getModel(t)
72+
client := getClient(t)
4573

46-
_, err := model.GenerateText(
74+
_, err := client.GenerateText(
4775
"meta-llama/llama-3-70b-instruct",
4876
"What day is it?",
4977
nil,
@@ -54,9 +82,9 @@ func TestNilOptions(t *testing.T) {
5482
}
5583

5684
func TestValidPrompt(t *testing.T) {
57-
model := getModel(t)
85+
client := getClient(t)
5886

59-
_, err := model.GenerateText(
87+
_, err := client.GenerateText(
6088
"meta-llama/llama-3-70b-instruct",
6189
"Test prompt",
6290
)
@@ -66,9 +94,9 @@ func TestValidPrompt(t *testing.T) {
6694
}
6795

6896
func TestGenerateText(t *testing.T) {
69-
model := getModel(t)
97+
client := getClient(t)
7098

71-
result, err := model.GenerateText(
99+
result, err := client.GenerateText(
72100
"meta-llama/llama-3-70b-instruct",
73101
"Hi, who are you?",
74102
wx.WithTemperature(0.9),
@@ -85,9 +113,9 @@ func TestGenerateText(t *testing.T) {
85113
}
86114

87115
func TestGenerateTextWithNilOptions(t *testing.T) {
88-
model := getModel(t)
116+
client := getClient(t)
89117

90-
result, err := model.GenerateText(
118+
result, err := client.GenerateText(
91119
"meta-llama/llama-3-70b-instruct",
92120
"Who are you?",
93121
nil,

pkg/models/model.go renamed to pkg/models/client.go

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
package models
22

3-
/*
4-
* https://ibm.github.io/watson-machine-learning-sdk/_modules/ibm_watson_machine_learning/foundation_models/model.html#Model
5-
*/
6-
73
import (
4+
"errors"
85
"fmt"
96
"net/http"
107
"os"
118
)
129

13-
type Model struct {
10+
type Client struct {
1411
url string
1512
region IBMCloudRegion
1613
apiVersion string
@@ -22,9 +19,9 @@ type Model struct {
2219
httpClient Doer
2320
}
2421

25-
func NewModel(options ...ModelOption) (*Model, error) {
22+
func NewClient(options ...ClientOption) (*Client, error) {
2623

27-
opts := defaulModelOptions()
24+
opts := defaulClientOptions()
2825
for _, opt := range options {
2926
if opt != nil {
3027
opt(opts)
@@ -36,13 +33,21 @@ func NewModel(options ...ModelOption) (*Model, error) {
3633
opts.URL = buildBaseURL(opts.Region)
3734
}
3835

39-
m := &Model{
36+
if opts.apiKey == "" {
37+
return nil, errors.New("no watsonx API key provided")
38+
}
39+
40+
if opts.projectID == "" {
41+
return nil, errors.New("no watsonx project ID provided")
42+
}
43+
44+
m := &Client{
4045
url: opts.URL,
4146
region: opts.Region,
4247
apiVersion: opts.APIVersion,
4348

4449
// token: set below
45-
apiKey: opts.watsonxAPIKey,
50+
apiKey: opts.apiKey,
4651
projectID: opts.projectID,
4752

4853
httpClient: &http.Client{},
@@ -57,15 +62,15 @@ func NewModel(options ...ModelOption) (*Model, error) {
5762
}
5863

5964
// CheckAndRefreshToken checks the IAM token if it expired; if it did, it refreshes it; nothing if not
60-
func (m *Model) CheckAndRefreshToken() error {
65+
func (m *Client) CheckAndRefreshToken() error {
6166
if m.token.Expired() {
6267
return m.RefreshToken()
6368
}
6469
return nil
6570
}
6671

6772
// RefreshToken generates and sets the model with a new token
68-
func (m *Model) RefreshToken() error {
73+
func (m *Client) RefreshToken() error {
6974
token, err := GenerateToken(m.httpClient, m.apiKey)
7075
if err != nil {
7176
return err
@@ -78,13 +83,13 @@ func buildBaseURL(region IBMCloudRegion) string {
7883
return fmt.Sprintf(BaseURLFormatStr, region)
7984
}
8085

81-
func defaulModelOptions() *ModelOptions {
82-
return &ModelOptions{
86+
func defaulClientOptions() *ClientOptions {
87+
return &ClientOptions{
8388
URL: "",
8489
Region: DefaultRegion,
8590
APIVersion: DefaultAPIVersion,
8691

87-
watsonxAPIKey: os.Getenv(WatsonxAPIKeyEnvVarName),
88-
projectID: os.Getenv(WatsonxProjectIDEnvVarName),
92+
apiKey: os.Getenv(WatsonxAPIKeyEnvVarName),
93+
projectID: os.Getenv(WatsonxProjectIDEnvVarName),
8994
}
9095
}

pkg/models/client_option.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package models
2+
3+
type ClientOption func(*ClientOptions)
4+
5+
type ClientOptions struct {
6+
URL string
7+
Region IBMCloudRegion
8+
APIVersion string
9+
10+
apiKey WatsonxAPIKey
11+
projectID WatsonxProjectID
12+
}
13+
14+
func WithURL(url string) ClientOption {
15+
return func(o *ClientOptions) {
16+
o.URL = url
17+
}
18+
}
19+
20+
func WithRegion(region IBMCloudRegion) ClientOption {
21+
return func(o *ClientOptions) {
22+
o.Region = region
23+
}
24+
}
25+
26+
func WithAPIVersion(apiVersion string) ClientOption {
27+
return func(o *ClientOptions) {
28+
o.APIVersion = apiVersion
29+
}
30+
}
31+
32+
func WithWatsonxAPIKey(watsonxAPIKey WatsonxAPIKey) ClientOption {
33+
return func(o *ClientOptions) {
34+
o.apiKey = watsonxAPIKey
35+
}
36+
}
37+
38+
func WithWatsonxProjectID(projectID WatsonxProjectID) ClientOption {
39+
return func(o *ClientOptions) {
40+
o.projectID = projectID
41+
}
42+
}

pkg/models/generate.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type generateTextResponse struct {
4949
}
5050

5151
// GenerateText generates completion text based on a given prompt and parameters
52-
func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) {
52+
func (m *Client) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) {
5353
m.CheckAndRefreshToken()
5454

5555
if prompt == "" {
@@ -86,7 +86,7 @@ func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (G
8686

8787
// generateTextRequest sends the generate request and handles the response using the http package.
8888
// Returns error on non-2XX response
89-
func (m *Model) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) {
89+
func (m *Client) generateTextRequest(payload GenerateTextPayload) (generateTextResponse, error) {
9090
params := url.Values{
9191
"version": {m.apiVersion},
9292
}

0 commit comments

Comments
 (0)