diff --git a/src/databricks/labs/blueprint/cli.py b/src/databricks/labs/blueprint/cli.py index 34b7b5b..5d15b0b 100644 --- a/src/databricks/labs/blueprint/cli.py +++ b/src/databricks/labs/blueprint/cli.py @@ -21,6 +21,7 @@ class Command: description: str fn: Callable[..., None] is_account: bool = False + is_collection: bool = False is_unauthenticated: bool = False def needs_workspace_client(self): @@ -30,6 +31,20 @@ def needs_workspace_client(self): return False return True + def run_as_collection(self) -> bool: + # A Method can be run as standalone workspace cmd or as a collection. To mark a method as collection method + # we need to add is_collection flag to True + # In addition if the collection_workspace_id is passed then return True else return False + # if collection_workspace_id is passed, the cmd should be run under account client else + # as workspace client. + if not self.is_collection: + return False + sig = inspect.signature(self.fn) + for param in sig.parameters.values(): + if param.name == "collection_workspace_id": + return True + return False + def prompts_argument_name(self) -> str | None: sig = inspect.signature(self.fn) for param in sig.parameters.values(): @@ -53,7 +68,7 @@ def __init__(self, __file: str): self._logger = get_logger(__file) self._product_info = ProductInfo(__file) - def command(self, fn=None, is_account: bool = False, is_unauthenticated: bool = False): + def command(self, fn=None, is_account: bool = False, is_unauthenticated: bool = False, is_collection=False): """Decorator to register a function as a command.""" def register(func): @@ -66,6 +81,7 @@ def register(func): fn=func, is_account=is_account, is_unauthenticated=is_unauthenticated, + is_collection=is_collection, ) return func @@ -99,9 +115,13 @@ def _route(self, raw): case "float": kwargs[kwarg] = float(kwargs[kwarg]) try: - if cmd.needs_workspace_client(): + if cmd.needs_workspace_client() and not cmd.run_as_collection(): + # if is_account is not set and cmd is either not a collection or + # is a collection but collection_workspace_id not passed kwargs["w"] = self._workspace_client() - elif cmd.is_account: + elif cmd.is_account or cmd.run_as_collection(): + # if is_account is set or cmd is a collection + # and collection_workspace_id is passed kwargs["a"] = self._account_client() prompts_argument = cmd.prompts_argument_name() if prompts_argument: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 9f3849e..09120a6 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -2,6 +2,9 @@ import json import sys from unittest import mock +from unittest.mock import create_autospec + +from databricks.sdk import AccountClient, WorkspaceClient from databricks.labs.blueprint.cli import App from databricks.labs.blueprint.tui import Prompts @@ -66,3 +69,54 @@ def foo( app() some.assert_called_with("y", 100, 100.5, True, "default", "optional") + + +def test_collection_commands_account(mocker): + some = mock.Mock() + app = App(inspect.getfile(App)) + acc_client = create_autospec(AccountClient) + mocker.patch("databricks.sdk.AccountClient.__new__", mock.Mock(return_value=acc_client)) + + @app.command(is_unauthenticated=False, is_collection=True) + def foo( + name: str, + age: int, + salary: float, + is_customer: bool, + a: AccountClient, + collection_workspace_id: int = 1234, + address: str = "default", + optional_arg: str | None = None, + ): + """Some comment""" + some(name, age, salary, is_customer, collection_workspace_id, address, optional_arg, a) + + with mock.patch.object(sys, "argv", [..., FOO_COMMAND]): + app() + + some.assert_called_with("y", 100, 100.5, True, 1234, "default", "optional", acc_client) + + +def test_collection_commands_workspace(mocker): + some = mock.Mock() + app = App(inspect.getfile(App)) + ws = create_autospec(WorkspaceClient) + mocker.patch("databricks.sdk.WorkspaceClient.__new__", mock.Mock(return_value=ws)) + + @app.command(is_unauthenticated=False, is_collection=True) + def foo( + name: str, + age: int, + salary: float, + is_customer: bool, + w: WorkspaceClient, + address: str = "default", + optional_arg: str | None = None, + ): + """Some comment""" + some(name, age, salary, is_customer, address, optional_arg, w) + + with mock.patch.object(sys, "argv", [..., FOO_COMMAND]): + app() + + some.assert_called_with("y", 100, 100.5, True, "default", "optional", ws)