Skip to content

Commit ced4c41

Browse files
committed
Add lazy warehouse resolution with smart discovery
Implements intelligent warehouse auto-discovery: - Add GetWarehouseID() with smart fallback chain: 1. Check DATABRICKS_WAREHOUSE_ID environment variable 2. Query available warehouses via API 3. Prefer RUNNING warehouses 4. Fall back to STOPPED warehouses (auto-start) 5. Use first available warehouse - Cache warehouse ID in session after first resolution - Remove warehouse ID from CLI flags and config - Update NewDatabricksRestClient and ResourcesFromEnv to use GetWarehouseID Eliminates need to specify warehouse ID upfront.
1 parent 537c596 commit ced4c41

File tree

4 files changed

+97
-7
lines changed

4 files changed

+97
-7
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package middlewares
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/url"
8+
9+
"github.com/databricks/cli/experimental/apps-mcp/lib/session"
10+
"github.com/databricks/cli/libs/env"
11+
"github.com/databricks/databricks-sdk-go/httpclient"
12+
"github.com/databricks/databricks-sdk-go/service/sql"
13+
)
14+
15+
func GetWarehouseID(ctx context.Context) (string, error) {
16+
sess, err := session.GetSession(ctx)
17+
if err != nil {
18+
return "", err
19+
}
20+
warehouseID, ok := sess.Get("warehouse_id")
21+
if !ok {
22+
warehouse, err := getDefaultWarehouse(ctx)
23+
if err != nil {
24+
return "", err
25+
}
26+
warehouseID = warehouse.Id
27+
sess.Set("warehouse_id", warehouseID.(string))
28+
}
29+
30+
return warehouseID.(string), nil
31+
}
32+
33+
func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
34+
// first resolve DATABRICKS_WAREHOUSE_ID env variable
35+
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
36+
if warehouseID != "" {
37+
w := MustGetDatabricksClient(ctx)
38+
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
39+
Id: warehouseID,
40+
})
41+
if err != nil {
42+
return nil, fmt.Errorf("get warehouse: %w", err)
43+
}
44+
return &sql.EndpointInfo{
45+
Id: warehouse.Id,
46+
}, nil
47+
}
48+
49+
apiClient, err := MustGetApiClient(ctx)
50+
if err != nil {
51+
return nil, err
52+
}
53+
54+
apiPath := "/api/2.0/sql/warehouses"
55+
params := url.Values{}
56+
params.Add("skip_cannot_use", "true")
57+
fullPath := fmt.Sprintf("%s?%s", apiPath, params.Encode())
58+
59+
var response sql.ListWarehousesResponse
60+
err = apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response))
61+
if err != nil {
62+
return nil, err
63+
}
64+
65+
warehouses := response.Warehouses
66+
67+
if len(warehouses) == 0 {
68+
return nil, errors.New("no warehouses found")
69+
}
70+
71+
// Prefer RUNNING warehouses
72+
for i := range warehouses {
73+
if warehouses[i].State == sql.StateRunning {
74+
return &warehouses[i], nil
75+
}
76+
}
77+
78+
// Fall back to STOPPED warehouses (they auto-start when queried)
79+
for i := range warehouses {
80+
if warehouses[i].State == sql.StateStopped {
81+
return &warehouses[i], nil
82+
}
83+
}
84+
85+
// Return first available warehouse regardless of state
86+
return &warehouses[0], nil
87+
}

experimental/apps-mcp/lib/providers/databricks/databricks.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,9 @@ type DatabricksRestClient struct {
435435
func NewDatabricksRestClient(ctx context.Context, cfg *mcp.Config) (*DatabricksRestClient, error) {
436436
client := middlewares.MustGetDatabricksClient(ctx)
437437

438-
warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID")
439-
if warehouseID == "" {
440-
return nil, errors.New("DATABRICKS_WAREHOUSE_ID not configured")
438+
warehouseID, err := middlewares.GetWarehouseID(ctx)
439+
if err != nil {
440+
return nil, fmt.Errorf("failed to get warehouse ID: %w", err)
441441
}
442442

443443
return &DatabricksRestClient{

experimental/apps-mcp/lib/providers/databricks/deployment.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package databricks
33
import (
44
"context"
55
"fmt"
6-
"os"
76
"os/exec"
87
"time"
98

109
mcp "github.com/databricks/cli/experimental/apps-mcp/lib"
10+
"github.com/databricks/cli/experimental/apps-mcp/lib/middlewares"
1111
"github.com/databricks/cli/libs/cmdctx"
1212
"github.com/databricks/databricks-sdk-go/service/apps"
1313
"github.com/databricks/databricks-sdk-go/service/iam"
@@ -102,8 +102,11 @@ func DeployApp(ctx context.Context, cfg *mcp.Config, appInfo *apps.App) error {
102102
return nil
103103
}
104104

105-
func ResourcesFromEnv(cfg *mcp.Config) (*apps.AppResource, error) {
106-
warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID")
105+
func ResourcesFromEnv(ctx context.Context, cfg *mcp.Config) (*apps.AppResource, error) {
106+
warehouseID, err := middlewares.GetWarehouseID(ctx)
107+
if err != nil {
108+
return nil, fmt.Errorf("failed to get warehouse ID: %w", err)
109+
}
107110

108111
return &apps.AppResource{
109112
Name: "base",

experimental/apps-mcp/lib/providers/deployment/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func (p *Provider) getOrCreateApp(ctx context.Context, name, description string,
289289

290290
log.Infof(ctx, "App not found, creating new app: name=%s", name)
291291

292-
resources, err := databricks.ResourcesFromEnv(p.config)
292+
resources, err := databricks.ResourcesFromEnv(ctx, p.config)
293293
if err != nil {
294294
return nil, err
295295
}

0 commit comments

Comments
 (0)