diff --git a/experimental/apps-mcp/lib/middlewares/warehouse.go b/experimental/apps-mcp/lib/middlewares/warehouse.go new file mode 100644 index 0000000000..8fdaf9125f --- /dev/null +++ b/experimental/apps-mcp/lib/middlewares/warehouse.go @@ -0,0 +1,94 @@ +package middlewares + +import ( + "context" + "errors" + "fmt" + "net/url" + "sort" + + "github.com/databricks/cli/experimental/apps-mcp/lib/session" + "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/service/sql" +) + +func GetWarehouseID(ctx context.Context) (string, error) { + sess, err := session.GetSession(ctx) + if err != nil { + return "", err + } + warehouseID, ok := sess.Get("warehouse_id") + if !ok { + warehouse, err := getDefaultWarehouse(ctx) + if err != nil { + return "", err + } + warehouseID = warehouse.Id + sess.Set("warehouse_id", warehouseID.(string)) + } + + return warehouseID.(string), nil +} + +func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { + // first resolve DATABRICKS_WAREHOUSE_ID env variable + warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID") + if warehouseID != "" { + w := MustGetDatabricksClient(ctx) + warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{ + Id: warehouseID, + }) + if err != nil { + return nil, fmt.Errorf("get warehouse: %w", err) + } + return &sql.EndpointInfo{ + Id: warehouse.Id, + }, nil + } + + apiClient, err := MustGetApiClient(ctx) + if err != nil { + return nil, err + } + + apiPath := "/api/2.0/sql/warehouses" + params := url.Values{} + params.Add("skip_cannot_use", "true") + fullPath := fmt.Sprintf("%s?%s", apiPath, params.Encode()) + + var response sql.ListWarehousesResponse + err = apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response)) + if err != nil { + return nil, err + } + + priorities := map[sql.State]int{ + sql.StateRunning: 1, + sql.StateStarting: 2, + sql.StateStopped: 3, + sql.StateStopping: 4, + sql.StateDeleted: 99, + sql.StateDeleting: 99, + } + + warehouses := response.Warehouses + sort.Slice(warehouses, func(i, j int) bool { + return priorities[warehouses[i].State] < priorities[warehouses[j].State] + }) + + if len(warehouses) == 0 { + return nil, errNoWarehouses() + } + + firstWarehouse := warehouses[0] + if firstWarehouse.State == sql.StateDeleted || firstWarehouse.State == sql.StateDeleting { + return nil, errNoWarehouses() + } + + return &firstWarehouse, nil +} + +func errNoWarehouses() error { + return errors.New("no warehouse found. You can explicitly set the warehouse ID using the DATABRICKS_WAREHOUSE_ID environment variable") +} diff --git a/experimental/apps-mcp/lib/providers/databricks/databricks.go b/experimental/apps-mcp/lib/providers/databricks/databricks.go index 7724ad997f..e3bca7f902 100644 --- a/experimental/apps-mcp/lib/providers/databricks/databricks.go +++ b/experimental/apps-mcp/lib/providers/databricks/databricks.go @@ -435,9 +435,9 @@ type DatabricksRestClient struct { func NewDatabricksRestClient(ctx context.Context, cfg *mcp.Config) (*DatabricksRestClient, error) { client := middlewares.MustGetDatabricksClient(ctx) - warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID") - if warehouseID == "" { - return nil, errors.New("DATABRICKS_WAREHOUSE_ID not configured") + warehouseID, err := middlewares.GetWarehouseID(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get warehouse ID: %w", err) } return &DatabricksRestClient{ diff --git a/experimental/apps-mcp/lib/providers/databricks/deployment.go b/experimental/apps-mcp/lib/providers/databricks/deployment.go index 756c336032..ac24d6120a 100644 --- a/experimental/apps-mcp/lib/providers/databricks/deployment.go +++ b/experimental/apps-mcp/lib/providers/databricks/deployment.go @@ -3,11 +3,11 @@ package databricks import ( "context" "fmt" - "os" "os/exec" "time" mcp "github.com/databricks/cli/experimental/apps-mcp/lib" + "github.com/databricks/cli/experimental/apps-mcp/lib/middlewares" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/databricks-sdk-go/service/apps" "github.com/databricks/databricks-sdk-go/service/iam" @@ -102,8 +102,11 @@ func DeployApp(ctx context.Context, cfg *mcp.Config, appInfo *apps.App) error { return nil } -func ResourcesFromEnv(cfg *mcp.Config) (*apps.AppResource, error) { - warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID") +func ResourcesFromEnv(ctx context.Context, cfg *mcp.Config) (*apps.AppResource, error) { + warehouseID, err := middlewares.GetWarehouseID(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get warehouse ID: %w", err) + } return &apps.AppResource{ Name: "base", diff --git a/experimental/apps-mcp/lib/providers/deployment/provider.go b/experimental/apps-mcp/lib/providers/deployment/provider.go index 9950efdfc3..d9273e5f6e 100644 --- a/experimental/apps-mcp/lib/providers/deployment/provider.go +++ b/experimental/apps-mcp/lib/providers/deployment/provider.go @@ -289,7 +289,7 @@ func (p *Provider) getOrCreateApp(ctx context.Context, name, description string, log.Infof(ctx, "App not found, creating new app: name=%s", name) - resources, err := databricks.ResourcesFromEnv(p.config) + resources, err := databricks.ResourcesFromEnv(ctx, p.config) if err != nil { return nil, err }