From 012b79b8d3c10a37f0b7ef14bb54a2b02edb3bf3 Mon Sep 17 00:00:00 2001 From: Ryan Slade Date: Fri, 19 Jan 2024 12:22:56 +0100 Subject: [PATCH] Add support for passing Workspace, Branch and Region via ClientOptions --- xata/client_options.go | 31 ++++++++++++--- xata/client_options_test.go | 29 ++++++++++---- xata/databases_client.go | 2 +- xata/utils.go | 45 +++++++++++++++------- xata/utils_test.go | 76 ++++++++++++++++++++++++------------- xata/workspaces_client.go | 2 +- 6 files changed, 130 insertions(+), 55 deletions(-) diff --git a/xata/client_options.go b/xata/client_options.go index a75cec6..563dbab 100644 --- a/xata/client_options.go +++ b/xata/client_options.go @@ -13,10 +13,13 @@ type httpClient interface { } type ClientOptions struct { - BaseURL string - HTTPClient httpClient - HTTPHeader http.Header - Bearer string + BaseURL string + HTTPClient httpClient + HTTPHeader http.Header + Bearer string + WorkspaceID string + Region string + Branch string } func consolidateClientOptionsForCore(opts ...ClientOption) (*ClientOptions, error) { @@ -56,7 +59,7 @@ func consolidateClientOptionsForWorkspace(opts ...ClientOption) (*ClientOptions, cliOpts.HTTPClient = http.DefaultClient } - dbCfg, err := loadDatabaseConfig() + dbCfg, err := loadDatabaseConfig(cliOpts) if err != nil && cliOpts.BaseURL == "" { return nil, nil, err } @@ -105,3 +108,21 @@ func WithBaseURL(baseURL string) func(options *ClientOptions) { options.BaseURL = baseURL } } + +func WithWorkspaceID(workspaceID string) func(options *ClientOptions) { + return func(options *ClientOptions) { + options.WorkspaceID = workspaceID + } +} + +func WithRegion(region string) func(options *ClientOptions) { + return func(options *ClientOptions) { + options.Region = region + } +} + +func WithBranch(branch string) func(options *ClientOptions) { + return func(options *ClientOptions) { + options.Branch = branch + } +} diff --git a/xata/client_options_test.go b/xata/client_options_test.go index dc3e827..2c0d430 100644 --- a/xata/client_options_test.go +++ b/xata/client_options_test.go @@ -10,22 +10,35 @@ import ( generatedwrapper "github.com/xataio/xata-go/xata" ) -func TestWithAPIToken(t *testing.T) { - t.Run("should use the provided API key in client options", func(t *testing.T) { +func TestClientOptions(t *testing.T) { + t.Run("WithAPIKey", func(t *testing.T) { c := &generatedwrapper.ClientOptions{} apiToken := "my-token" generatedwrapper.WithAPIKey("my-token")(c) - assert.Equal(t, apiToken, c.Bearer) }) -} - -func TestWithHTTPClient(t *testing.T) { - t.Run("should use the provided HTTP client in client options", func(t *testing.T) { + t.Run("WithHTTPClient", func(t *testing.T) { c := &generatedwrapper.ClientOptions{} cli := &http.Client{} generatedwrapper.WithHTTPClient(cli)(c) - assert.Equal(t, cli, c.HTTPClient) }) + t.Run("WithWorkspaceID", func(t *testing.T) { + c := &generatedwrapper.ClientOptions{} + workspaceID := "workspace-123" + generatedwrapper.WithWorkspaceID(workspaceID)(c) + assert.Equal(t, workspaceID, c.WorkspaceID) + }) + t.Run("WithBranch", func(t *testing.T) { + c := &generatedwrapper.ClientOptions{} + branch := "branch-123" + generatedwrapper.WithBranch(branch)(c) + assert.Equal(t, branch, c.Branch) + }) + t.Run("WithRegion", func(t *testing.T) { + c := &generatedwrapper.ClientOptions{} + region := "region-123" + generatedwrapper.WithRegion(region)(c) + assert.Equal(t, region, c.Region) + }) } diff --git a/xata/databases_client.go b/xata/databases_client.go index 3174d4f..5219db2 100644 --- a/xata/databases_client.go +++ b/xata/databases_client.go @@ -143,7 +143,7 @@ func NewDatabasesClient(opts ...ClientOption) (DatabasesClient, error) { return nil, err } - dbCfg, err := loadDatabaseConfig() + dbCfg, err := loadDatabaseConfig(cliOpts) if err != nil { // No err, because the config values can be provided by the users. log.Println(err) diff --git a/xata/utils.go b/xata/utils.go index 3fecee3..630d2e0 100644 --- a/xata/utils.go +++ b/xata/utils.go @@ -138,7 +138,7 @@ func parseDatabaseURL(rawURL string) (databaseConfig, error) { } if db.branchName == "" { - db.branchName = getBranchName() + db.branchName = getBranchName(nil) } return db, err @@ -188,39 +188,58 @@ func getEnvVar(name string, defaultValue string) string { return defaultValue } -// getBranchName retrieves the branch name. -// If not found, falls back to defaultBranchName -func getBranchName() string { +// getBranchName retrieves the branch name. If not found, falls back to defaultBranchName. +func getBranchName(opts *ClientOptions) string { + if opts != nil && opts.Branch != "" { + return opts.Branch + } return getEnvVar(EnvXataBranch, defaultBranchName) } -// Get the region if the corresponding env var `XATA_REGION` is set -// otherwise return the default region: us-east-1 -func getRegion() string { +// getRegion gets the region if the corresponding env var `XATA_REGION` is set otherwise return +// defaultRegion. +func getRegion(opts *ClientOptions) string { + if opts != nil && opts.Region != "" { + return opts.Region + } return getEnvVar(EnvXataRegion, defaultRegion) } +// getWorkspaceID gets the workspace id from opts and if empty, gets it from the `XATA_WORKSPACE_ID` +// environment variable +func getWorkspaceID(opts *ClientOptions) string { + if opts != nil && opts.WorkspaceID != "" { + return opts.WorkspaceID + } + return getEnvVar(EnvXataWorkspaceID, "") +} + // loadDatabaseConfig will return config with defaults if the error is not nil. -func loadDatabaseConfig() (databaseConfig, error) { +func loadDatabaseConfig(cliOpts *ClientOptions) (databaseConfig, error) { defaultDBConfig := databaseConfig{ region: defaultRegion, branchName: defaultBranchName, domainWorkspace: defaultDataPlaneDomain, } - // Setup with env var - // XATA_WORKSPACE_ID to set the workspace ID - wsID := getEnvVar(EnvXataWorkspaceID, "") + // Config can come from three places with differing priorities. The order from highest to lowest + // priority is: + // 1. Code via ClientOptions + // 2. Environment variables + // 3. Config files + + wsID := getWorkspaceID(cliOpts) if wsID != "" { db := databaseConfig{ workspaceID: wsID, - region: getRegion(), - branchName: getBranchName(), + region: getRegion(cliOpts), + branchName: getBranchName(cliOpts), domainWorkspace: defaultDataPlaneDomain, } return db, nil } + // Config not found in code or environment variables, fall back to config files cfg, err := loadConfig(configFileName) if err != nil { return defaultDBConfig, err diff --git a/xata/utils_test.go b/xata/utils_test.go index b053fd4..f50d2d9 100644 --- a/xata/utils_test.go +++ b/xata/utils_test.go @@ -29,29 +29,31 @@ func TestClientOptions_getAPIKey(t *testing.T) { func Test_getBranchName(t *testing.T) { // default state t.Run("should be default branch name", func(t *testing.T) { - gotBranchName := getBranchName() + gotBranchName := getBranchName(nil) assert.Equal(t, gotBranchName, defaultBranchName) }) setBranchName := "feature-042" - err := os.Setenv(EnvXataBranch, setBranchName) - if err != nil { - t.Fatal(err) - } + setEnvForTests(t, EnvXataBranch, setBranchName) // from env var t.Run("should be branch name from env var", func(t *testing.T) { - gotBranchName := getBranchName() - assert.Equal(t, gotBranchName, setBranchName) + gotBranchName := getBranchName(nil) + assert.Equal(t, setBranchName, gotBranchName) }) - t.Cleanup(func() { os.Unsetenv(EnvXataBranch) }) + // from ClientOptions + t.Run("from ClientOptions", func(t *testing.T) { + want := "branch-from-opts" + got := getBranchName(&ClientOptions{Branch: want}) + assert.Equal(t, want, got) + }) } func Test_getRegion(t *testing.T) { // default state t.Run("should be default region", func(t *testing.T) { - gotRegion := getRegion() + gotRegion := getRegion(nil) assert.Equal(t, gotRegion, defaultRegion) }) @@ -63,8 +65,14 @@ func Test_getRegion(t *testing.T) { // from env var t.Run("should be region from the env var", func(t *testing.T) { - gotRegion := getRegion() - assert.Equal(t, gotRegion, setRegion) + gotRegion := getRegion(nil) + assert.Equal(t, setRegion, gotRegion) + }) + + t.Run("should be region from ClientOptions", func(t *testing.T) { + wantRegion := "region-options" + gotRegion := getRegion(&ClientOptions{Region: wantRegion}) + assert.Equal(t, wantRegion, gotRegion) }) t.Cleanup(func() { os.Unsetenv(EnvXataRegion) }) @@ -176,14 +184,31 @@ func Test_loadConfig(t *testing.T) { func Test_loadDatabaseConfig_with_envvars(t *testing.T) { setWsId := "workspace-0lac00" - err := os.Setenv(EnvXataWorkspaceID, setWsId) - if err != nil { - t.Fatal(err) + + assertClientOptions := func(t *testing.T) { + wsID := "workspace-fco" + opts := &ClientOptions{ + WorkspaceID: wsID, + } + dbCfg, err := loadDatabaseConfig(opts) + assert.NoError(t, err) + assert.Equal(t, wsID, dbCfg.workspaceID) + assert.Equal(t, defaultBranchName, dbCfg.branchName) + assert.Equal(t, defaultRegion, dbCfg.region) + assert.Equal(t, defaultDataPlaneDomain, dbCfg.domainWorkspace) } + // test workspace is from ClientOptions + t.Run("load config from ClientOptions", assertClientOptions) + + setEnvForTests(t, EnvXataWorkspaceID, setWsId) + + // Check again after environment variable set + t.Run("config from ClientOptions takes precedence", assertClientOptions) + // test workspace id from env var t.Run("load config from WORKSPACE_ID env var", func(t *testing.T) { - dbCfg, err := loadDatabaseConfig() + dbCfg, err := loadDatabaseConfig(nil) assert.NoError(t, err) assert.Equal(t, setWsId, dbCfg.workspaceID) assert.Equal(t, defaultBranchName, dbCfg.branchName) @@ -192,28 +217,25 @@ func Test_loadDatabaseConfig_with_envvars(t *testing.T) { }) setBranch := "branch123" - err2 := os.Setenv(EnvXataBranch, setBranch) - if err2 != nil { - t.Fatal(err2) - } + setEnvForTests(t, EnvXataBranch, setBranch) setRegion := "ap-southeast-16" - err3 := os.Setenv(EnvXataRegion, setRegion) - if err3 != nil { - t.Fatal(err3) - } + setEnvForTests(t, EnvXataRegion, setRegion) // with branch and region env vars t.Run("load config from XATA_WORKSPACE_ID, XATA_REGION and XATA_BRANCH env vars", func(t *testing.T) { - dbCfg, err := loadDatabaseConfig() + dbCfg, err := loadDatabaseConfig(nil) assert.NoError(t, err) assert.Equal(t, setWsId, dbCfg.workspaceID) assert.Equal(t, setBranch, dbCfg.branchName) assert.Equal(t, setRegion, dbCfg.region) }) +} +func setEnvForTests(t *testing.T, key, value string) { + t.Helper() + err := os.Setenv(key, value) + assert.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, os.Unsetenv(EnvXataWorkspaceID)) - assert.NoError(t, os.Unsetenv(EnvXataBranch)) - assert.NoError(t, os.Unsetenv(EnvXataRegion)) + assert.NoError(t, os.Unsetenv(key)) }) } diff --git a/xata/workspaces_client.go b/xata/workspaces_client.go index 93efcd6..cf9f841 100644 --- a/xata/workspaces_client.go +++ b/xata/workspaces_client.go @@ -79,7 +79,7 @@ func NewWorkspacesClient(opts ...ClientOption) (WorkspacesClient, error) { return nil, err } - dbCfg, err := loadDatabaseConfig() + dbCfg, err := loadDatabaseConfig(cliOpts) if err != nil { // No err, because the config values can be provided by the users. log.Println(err)