diff --git a/src/databricks/labs/blueprint/cli.py b/src/databricks/labs/blueprint/cli.py index 726022b..2060374 100644 --- a/src/databricks/labs/blueprint/cli.py +++ b/src/databricks/labs/blueprint/cli.py @@ -36,6 +36,12 @@ def prompts_argument_name(self) -> str | None: return param.name return None + def get_argument_type(self, argument_name: str) -> str | None: + sig = inspect.signature(self.fn) + if argument_name not in sig.parameters: + return None + return sig.parameters[argument_name].annotation.__name__ + class App: def __init__(self, __file: str): @@ -77,9 +83,18 @@ def _route(self, raw): log_level = "info" databricks_logger = logging.getLogger("databricks") databricks_logger.setLevel(log_level.upper()) - kwargs = {k.replace("-", "_"): v for k, v in flags.items()} + kwargs = {k.replace("-", "_"): v for k, v in flags.items() if v != ""} + cmd = self._mapping[command] + # modify kwargs to match the type of the argument + for kwarg in list(kwargs.keys()): + match cmd.get_argument_type(kwarg): + case "int": + kwargs[kwarg] = int(kwargs[kwarg]) + case "bool": + kwargs[kwarg] = kwargs[kwarg].lower() == "true" + case "float": + kwargs[kwarg] = float(kwargs[kwarg]) try: - cmd = self._mapping[command] if cmd.needs_workspace_client(): kwargs["w"] = self._workspace_client() elif cmd.is_account: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 89e822d..a471fe0 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -11,6 +11,10 @@ "command": "foo", "flags": { "name": "y", + "age": "100", + "salary": "100.5", + "address": "", + "is_customer": "true", "log_level": "disabled", }, } @@ -22,14 +26,14 @@ def test_commands(): app = App(inspect.getfile(App)) @app.command(is_unauthenticated=True) - def foo(name: str): + def foo(name: str, age: int, salary: float, is_customer: bool, address: str = "default"): """Some comment""" - some(name) + some(name, age, salary, is_customer, address) with mock.patch.object(sys, "argv", [..., FOO_COMMAND]): app() - some.assert_called_with("y") + some.assert_called_with("y", 100, 100.5, True, "default") def test_injects_prompts(): @@ -37,12 +41,12 @@ def test_injects_prompts(): app = App(inspect.getfile(App)) @app.command(is_unauthenticated=True) - def foo(name: str, prompts: Prompts): + def foo(name: str, age: int, salary: float, is_customer: bool, prompts: Prompts, address: str = "default"): """Some comment""" assert isinstance(prompts, Prompts) - some(name) + some(name, age, salary, is_customer, address) with mock.patch.object(sys, "argv", [..., FOO_COMMAND]): app() - some.assert_called_with("y") + some.assert_called_with("y", 100, 100.5, True, "default")