From 8300dd7d247090951da38dbaed3acf6e50862e52 Mon Sep 17 00:00:00 2001 From: vuong-nguyen <44292934+nkvuong@users.noreply.github.com> Date: Thu, 30 May 2024 21:20:35 +0100 Subject: [PATCH] fixed `Command.get_argument_type` bug with `UnionType` (#110) `Command.get_argument_type` currently crashes when `UnionType` is encountered. Add special handling for this type --- src/databricks/labs/blueprint/cli.py | 6 +++++- tests/unit/test_cli.py | 28 ++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/databricks/labs/blueprint/cli.py b/src/databricks/labs/blueprint/cli.py index 2060374..34b7b5b 100644 --- a/src/databricks/labs/blueprint/cli.py +++ b/src/databricks/labs/blueprint/cli.py @@ -4,6 +4,7 @@ import inspect import json import logging +import types from collections.abc import Callable from dataclasses import dataclass @@ -40,7 +41,10 @@ def get_argument_type(self, argument_name: str) -> str | None: sig = inspect.signature(self.fn) if argument_name not in sig.parameters: return None - return sig.parameters[argument_name].annotation.__name__ + annotation = sig.parameters[argument_name].annotation + if isinstance(annotation, types.UnionType): + return str(annotation) + return annotation.__name__ class App: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index a471fe0..9f3849e 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -16,6 +16,7 @@ "address": "", "is_customer": "true", "log_level": "disabled", + "optional_arg": "optional", }, } ) @@ -26,14 +27,21 @@ def test_commands(): app = App(inspect.getfile(App)) @app.command(is_unauthenticated=True) - def foo(name: str, age: int, salary: float, is_customer: bool, address: str = "default"): + def foo( + name: str, + age: int, + salary: float, + is_customer: bool, + address: str = "default", + optional_arg: str | None = None, + ): """Some comment""" - some(name, age, salary, is_customer, address) + some(name, age, salary, is_customer, address, optional_arg) with mock.patch.object(sys, "argv", [..., FOO_COMMAND]): app() - some.assert_called_with("y", 100, 100.5, True, "default") + some.assert_called_with("y", 100, 100.5, True, "default", "optional") def test_injects_prompts(): @@ -41,12 +49,20 @@ def test_injects_prompts(): app = App(inspect.getfile(App)) @app.command(is_unauthenticated=True) - def foo(name: str, age: int, salary: float, is_customer: bool, prompts: Prompts, address: str = "default"): + def foo( + name: str, + age: int, + salary: float, + is_customer: bool, + prompts: Prompts, + address: str = "default", + optional_arg: str | None = None, + ): """Some comment""" assert isinstance(prompts, Prompts) - some(name, age, salary, is_customer, address) + some(name, age, salary, is_customer, address, optional_arg) with mock.patch.object(sys, "argv", [..., FOO_COMMAND]): app() - some.assert_called_with("y", 100, 100.5, True, "default") + some.assert_called_with("y", 100, 100.5, True, "default", "optional")