diff --git a/xata/utils.go b/xata/utils.go index a4c5796..80db0d5 100644 --- a/xata/utils.go +++ b/xata/utils.go @@ -184,6 +184,15 @@ func getBranchName() string { return defaultBranchName } +// Get the region if the corresponding env var `XATA_REGION` is set +// otherwise return the default region: us-east-1 +func getRegion() string { + if region, found := os.LookupEnv("XATA_REGION"); found { + return region + } + return defaultRegion +} + // loadDatabaseConfig will return config with defaults if the error is not nil. func loadDatabaseConfig() (databaseConfig, error) { defaultDBConfig := databaseConfig{ @@ -191,6 +200,19 @@ func loadDatabaseConfig() (databaseConfig, error) { branchName: defaultBranchName, domainWorkspace: defaultDataPlaneDomain, } + // Setup with env var + // XATA_WORKSPACE_ID to set the workspace Id + if wsID, found := os.LookupEnv("XATA_WORKSPACE_ID"); found { + region := getRegion() + branch := getBranchName() + db := databaseConfig{ + workspaceID: wsID, + region: region, + branchName: branch, + } + return db, nil + } + cfg, err := loadConfig(configFileName) if err != nil { return defaultDBConfig, err diff --git a/xata/utils_test.go b/xata/utils_test.go index 8734022..a4d6ea1 100644 --- a/xata/utils_test.go +++ b/xata/utils_test.go @@ -26,6 +26,50 @@ func TestClientOptions_getAPIKey(t *testing.T) { }) } +func Test_getBranchName(t *testing.T) { + // default state + t.Run("should be default branch name", func(t *testing.T) { + gotBranchName := getBranchName() + assert.Equal(t, gotBranchName, defaultBranchName) + }) + + setBranchName := "feature-042" + err := os.Setenv(branchNameEnvVar, setBranchName) + if err != nil { + t.Fatal(err) + } + + // from env var + t.Run("should be branch name from env var", func(t *testing.T) { + gotBranchName := getBranchName() + assert.Equal(t, gotBranchName, setBranchName) + }) + + t.Cleanup(func() { os.Unsetenv(branchNameEnvVar) }) +} + +func Test_getRegion(t *testing.T) { + // default state + t.Run("should be default region", func(t *testing.T) { + gotRegion := getRegion() + assert.Equal(t, gotRegion, defaultRegion) + }) + + setRegion := "eu-west-3" + err := os.Setenv("XATA_REGION", setRegion) + if err != nil { + t.Fatal(err) + } + + // from env var + t.Run("should be region from the env var", func(t *testing.T) { + gotRegion := getRegion() + assert.Equal(t, gotRegion, setRegion) + }) + + t.Cleanup(func() { os.Unsetenv("XATA_REGION") }) +} + func Test_parseDatabaseURL(t *testing.T) { tests := []struct { name string @@ -95,6 +139,7 @@ func Test_parseDatabaseURL(t *testing.T) { } func Test_loadConfig(t *testing.T) { + // from .xatarc t.Run("should read database URL", func(t *testing.T) { // Create a temporary JSON file for testing tempFile, err := os.CreateTemp("", "config_test.json") @@ -128,3 +173,64 @@ func Test_loadConfig(t *testing.T) { } }) } + +func Test_loadDatabaseConfig_with_envvars(t *testing.T) { + setWsId := "workspace-0lac00" + err := os.Setenv("XATA_WORKSPACE_ID", setWsId) + if err != nil { + t.Fatal(err) + } + + // test workspace id from env var + t.Run("load config from WORKSPACE_ID env var", func(t *testing.T) { + dbCfg, err := loadDatabaseConfig() + if err != nil { + t.Fatalf("Error loading config: %v", err) + } + + if dbCfg.workspaceID != setWsId { + t.Fatalf("Expected Workspace ID: %s, got: %s", setWsId, dbCfg.workspaceID) + } + if dbCfg.branchName != defaultBranchName { + t.Fatalf("Expected branch name: %s, got: %s", defaultBranchName, dbCfg.branchName) + } + if dbCfg.region != defaultRegion { + t.Fatalf("Expected region: %s, got: %s", defaultRegion, dbCfg.region) + } + }) + + setBranch := "branch123" + err2 := os.Setenv(branchNameEnvVar, setBranch) + if err2 != nil { + t.Fatal(err2) + } + setRegion := "ap-southeast-16" + err3 := os.Setenv("XATA_REGION", setRegion) + if err3 != nil { + t.Fatal(err3) + } + + // with branch and region env vars + t.Run("load config from XATA_WORKSPACE_ID, XATA_REGION and XATA_BRANCH env vars", func(t *testing.T) { + dbCfg, err := loadDatabaseConfig() + if err != nil { + t.Fatalf("Error loading config: %v", err) + } + + if dbCfg.workspaceID != setWsId { + t.Fatalf("Expected Workspace ID: %s, got: %s", setWsId, dbCfg.workspaceID) + } + if dbCfg.branchName != setBranch { + t.Fatalf("Expected branch name: %s, got: %s", setBranch, dbCfg.branchName) + } + if dbCfg.region != setRegion { + t.Fatalf("Expected region: %s, got: %s", setRegion, dbCfg.region) + } + }) + + t.Cleanup(func() { + os.Unsetenv("XATA_WORKSPACE_ID") + os.Unsetenv("XATA_BRANCH") + os.Unsetenv("XATA_REGION") + }) +}