From 0296eaa0d88c4b17889db9ca607d1a3f9a1a1f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 11:19:53 +0200 Subject: [PATCH 01/15] :pencil: Add basic pydantic example to docs --- docs_src/parameter_types/pydantic/__init__.py | 0 docs_src/parameter_types/pydantic/tutorial001.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 docs_src/parameter_types/pydantic/__init__.py create mode 100644 docs_src/parameter_types/pydantic/tutorial001.py diff --git a/docs_src/parameter_types/pydantic/__init__.py b/docs_src/parameter_types/pydantic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs_src/parameter_types/pydantic/tutorial001.py b/docs_src/parameter_types/pydantic/tutorial001.py new file mode 100644 index 0000000000..8e233ebd98 --- /dev/null +++ b/docs_src/parameter_types/pydantic/tutorial001.py @@ -0,0 +1,16 @@ +import pydantic +import typer + + +class User(pydantic.BaseModel): + id: int + name: str = "Jane Doe" + + +def main(num: int, user: User): + print(num, type(num)) + print(user, type(user)) + + +if __name__ == "__main__": + typer.run(main) From 46ac20fd3c5c63b42ab20b2cf553ab4127583a98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 11:20:08 +0200 Subject: [PATCH 02/15] :white_check_mark: Add basic testcase --- .../test_pydantic/__init__.py | 0 .../test_pydantic/test_tutorial001.py | 33 +++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 tests/test_tutorial/test_parameter_types/test_pydantic/__init__.py create mode 100644 tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/__init__.py b/tests/test_tutorial/test_parameter_types/test_pydantic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py new file mode 100644 index 0000000000..256bdd67ba --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py @@ -0,0 +1,33 @@ +import subprocess +import sys + +import typer +from typer.testing import CliRunner + +from docs_src.parameter_types.pydantic import tutorial001 as mod + +runner = CliRunner() + +app = typer.Typer() +app.command()(mod.main) + + +def test_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + + + +def test_parse_pydantic_model(): + result = runner.invoke(app, ["1", "--user-id", "2", "--user-name", "John Doe"]) + assert "1 " in result.output + assert "id=2 name='John Doe' " in result.output + + +def test_script(): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout From 5c1ee1d73ed8731194d0daad26fa620deaa73597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 11:20:41 +0200 Subject: [PATCH 03/15] :sparkles: Basic implementation --- typer/main.py | 7 ++-- typer/pydantic_extension.py | 66 +++++++++++++++++++++++++++++++++++++ typer/utils.py | 7 ++-- 3 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 typer/pydantic_extension.py diff --git a/typer/main.py b/typer/main.py index 9db26975ca..e3ca764f93 100644 --- a/typer/main.py +++ b/typer/main.py @@ -13,6 +13,8 @@ import click +from typer.pydantic_extension import wrap_pydantic_callback + from .completion import get_completion_inspect_parameters from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption from .models import ( @@ -572,17 +574,18 @@ def get_command_from_info( use_help = inspect.getdoc(command_info.callback) else: use_help = inspect.cleandoc(use_help) + callback = wrap_pydantic_callback(command_info.callback) ( params, convertors, context_param_name, - ) = get_params_convertors_ctx_param_name_from_function(command_info.callback) + ) = get_params_convertors_ctx_param_name_from_function(callback) cls = command_info.cls or TyperCommand command = cls( name=name, context_settings=command_info.context_settings, callback=get_callback( - callback=command_info.callback, + callback=callback, params=params, convertors=convertors, context_param_name=context_param_name, diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py new file mode 100644 index 0000000000..eb236f0a7f --- /dev/null +++ b/typer/pydantic_extension.py @@ -0,0 +1,66 @@ +import copy +import inspect +from typing import Annotated, Callable + +from pydantic_core import PydanticUndefined +from pydantic._internal._utils import deep_update + +from .params import Option +from .utils import inspect_signature + +import pydantic + +def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> dict: + from .main import lenient_issubclass + pydantic_parameters = {} + for field_name, field in model.model_fields.items(): + qualifier = [*ancestors, field_name] + sub_name = f"{'_'.join(qualifier)}" + if lenient_issubclass(field.annotation, pydantic.BaseModel): + pydantic_parameters.update(flatten_pydantic_model(field.annotation, qualifier)) + else: + pydantic_parameters[sub_name] = inspect.Parameter( + sub_name, + inspect.Parameter.KEYWORD_ONLY, + annotation=Annotated[field.annotation, Option(), qualifier], + default=field.default if field.default != PydanticUndefined else inspect.Parameter.empty, + ) + return pydantic_parameters + + +def wrap_pydantic_callback(callback: Callable) -> Callable: + from .main import lenient_issubclass + original_signature = inspect_signature(callback) + + pydantic_parameters = {} + pydantic_roots = {} + other_parameters = {} + for name, parameter in original_signature.parameters.items(): + if lenient_issubclass(parameter.annotation, pydantic.BaseModel): + pydantic_parameters.update(flatten_pydantic_model(parameter.annotation, [name])) + pydantic_roots[name] = parameter.annotation + else: + other_parameters[name] = parameter + + extended_signature = inspect.Signature( + [*other_parameters.values(), *pydantic_parameters.values(),], + return_annotation=original_signature.return_annotation, + ) + + def wrapper(*args, **kwargs): + converted_kwargs = kwargs.copy() + pydantic_dicts = {} + for kwarg_name, kwarg_value in kwargs.items(): + if kwarg_name in pydantic_parameters: + converted_kwargs.pop(kwarg_name) + annotation: Annotated = pydantic_parameters[kwarg_name].annotation + _, qualifier = annotation.__metadata__ + for part in reversed(qualifier): + kwarg_value = {part: kwarg_value} + pydantic_dicts = deep_update(pydantic_dicts, kwarg_value) + for root_name, value in pydantic_dicts.items(): + converted_kwargs[root_name] = pydantic_roots[root_name](**value) + return callback(*args, **converted_kwargs) + + wrapper.__signature__ = extended_signature + return wrapper diff --git a/typer/utils.py b/typer/utils.py index 2ba7bace45..687d5a4805 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -105,12 +105,15 @@ def _split_annotation_from_typer_annotations( if isinstance(annotation, ParameterInfo) ] - -def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: +def inspect_signature(func: Callable) -> inspect.Signature: if sys.version_info >= (3, 10): signature = inspect.signature(func, eval_str=True) else: signature = inspect.signature(func) + return signature + +def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: + signature = inspect_signature(func) type_hints = get_type_hints(func) params = {} From e91521473fe4ef7532b75d94edc668eebdc9f38d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 11:48:46 +0200 Subject: [PATCH 04/15] :sparkles: Change default field separator to . --- .../test_pydantic/test_tutorial001.py | 2 +- typer/pydantic_extension.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py index 256bdd67ba..0ef2548229 100644 --- a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py @@ -19,7 +19,7 @@ def test_help(): def test_parse_pydantic_model(): - result = runner.invoke(app, ["1", "--user-id", "2", "--user-name", "John Doe"]) + result = runner.invoke(app, ["1", "--user.id", "2", "--user.name", "John Doe"]) assert "1 " in result.output assert "id=2 name='John Doe' " in result.output diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index eb236f0a7f..fb07d83225 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -1,4 +1,3 @@ -import copy import inspect from typing import Annotated, Callable @@ -10,20 +9,24 @@ import pydantic +PYDANTIC_FIELD_SEPARATOR = "." + def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> dict: from .main import lenient_issubclass pydantic_parameters = {} for field_name, field in model.model_fields.items(): qualifier = [*ancestors, field_name] - sub_name = f"{'_'.join(qualifier)}" + sub_name = f"_pydantic_{'_'.join(qualifier)}" if lenient_issubclass(field.annotation, pydantic.BaseModel): pydantic_parameters.update(flatten_pydantic_model(field.annotation, qualifier)) else: + default = field.default if field.default != PydanticUndefined else ... + typer_option = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") pydantic_parameters[sub_name] = inspect.Parameter( sub_name, - inspect.Parameter.KEYWORD_ONLY, - annotation=Annotated[field.annotation, Option(), qualifier], - default=field.default if field.default != PydanticUndefined else inspect.Parameter.empty, + inspect.Parameter.KEYWORD_ONLY, + annotation=Annotated[field.annotation, typer_option, qualifier], + default=default, ) return pydantic_parameters @@ -43,7 +46,7 @@ def wrap_pydantic_callback(callback: Callable) -> Callable: other_parameters[name] = parameter extended_signature = inspect.Signature( - [*other_parameters.values(), *pydantic_parameters.values(),], + [*other_parameters.values(), *pydantic_parameters.values()], return_annotation=original_signature.return_annotation, ) From bfa4e6d9c18a2bbf60563ee5cd54a40927ca4ac3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Apr 2024 10:28:52 +0000 Subject: [PATCH 05/15] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../parameter_types/pydantic/tutorial001.py | 3 +- .../test_pydantic/test_tutorial001.py | 6 ++-- typer/main.py | 1 - typer/pydantic_extension.py | 28 +++++++++++-------- typer/utils.py | 2 ++ 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/docs_src/parameter_types/pydantic/tutorial001.py b/docs_src/parameter_types/pydantic/tutorial001.py index 8e233ebd98..2265ec987c 100644 --- a/docs_src/parameter_types/pydantic/tutorial001.py +++ b/docs_src/parameter_types/pydantic/tutorial001.py @@ -1,6 +1,7 @@ -import pydantic import typer +import pydantic + class User(pydantic.BaseModel): id: int diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py index 0ef2548229..1d0d81a60b 100644 --- a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py @@ -17,11 +17,13 @@ def test_help(): assert result.exit_code == 0 - def test_parse_pydantic_model(): result = runner.invoke(app, ["1", "--user.id", "2", "--user.name", "John Doe"]) assert "1 " in result.output - assert "id=2 name='John Doe' " in result.output + assert ( + "id=2 name='John Doe' " + in result.output + ) def test_script(): diff --git a/typer/main.py b/typer/main.py index e3ca764f93..f7e1173282 100644 --- a/typer/main.py +++ b/typer/main.py @@ -12,7 +12,6 @@ from uuid import UUID import click - from typer.pydantic_extension import wrap_pydantic_callback from .completion import get_completion_inspect_parameters diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index fb07d83225..c9bff2047e 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -1,31 +1,34 @@ import inspect from typing import Annotated, Callable -from pydantic_core import PydanticUndefined +import pydantic from pydantic._internal._utils import deep_update +from pydantic_core import PydanticUndefined from .params import Option from .utils import inspect_signature -import pydantic - PYDANTIC_FIELD_SEPARATOR = "." + def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> dict: from .main import lenient_issubclass + pydantic_parameters = {} for field_name, field in model.model_fields.items(): qualifier = [*ancestors, field_name] sub_name = f"_pydantic_{'_'.join(qualifier)}" if lenient_issubclass(field.annotation, pydantic.BaseModel): - pydantic_parameters.update(flatten_pydantic_model(field.annotation, qualifier)) + pydantic_parameters.update( + flatten_pydantic_model(field.annotation, qualifier) + ) else: default = field.default if field.default != PydanticUndefined else ... typer_option = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") pydantic_parameters[sub_name] = inspect.Parameter( - sub_name, + sub_name, inspect.Parameter.KEYWORD_ONLY, - annotation=Annotated[field.annotation, typer_option, qualifier], + annotation=Annotated[field.annotation, typer_option, qualifier], default=default, ) return pydantic_parameters @@ -33,6 +36,7 @@ def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> d def wrap_pydantic_callback(callback: Callable) -> Callable: from .main import lenient_issubclass + original_signature = inspect_signature(callback) pydantic_parameters = {} @@ -40,14 +44,16 @@ def wrap_pydantic_callback(callback: Callable) -> Callable: other_parameters = {} for name, parameter in original_signature.parameters.items(): if lenient_issubclass(parameter.annotation, pydantic.BaseModel): - pydantic_parameters.update(flatten_pydantic_model(parameter.annotation, [name])) + pydantic_parameters.update( + flatten_pydantic_model(parameter.annotation, [name]) + ) pydantic_roots[name] = parameter.annotation else: - other_parameters[name] = parameter + other_parameters[name] = parameter extended_signature = inspect.Signature( - [*other_parameters.values(), *pydantic_parameters.values()], - return_annotation=original_signature.return_annotation, + [*other_parameters.values(), *pydantic_parameters.values()], + return_annotation=original_signature.return_annotation, ) def wrapper(*args, **kwargs): @@ -64,6 +70,6 @@ def wrapper(*args, **kwargs): for root_name, value in pydantic_dicts.items(): converted_kwargs[root_name] = pydantic_roots[root_name](**value) return callback(*args, **converted_kwargs) - + wrapper.__signature__ = extended_signature return wrapper diff --git a/typer/utils.py b/typer/utils.py index 687d5a4805..232a094fcf 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -105,6 +105,7 @@ def _split_annotation_from_typer_annotations( if isinstance(annotation, ParameterInfo) ] + def inspect_signature(func: Callable) -> inspect.Signature: if sys.version_info >= (3, 10): signature = inspect.signature(func, eval_str=True) @@ -112,6 +113,7 @@ def inspect_signature(func: Callable) -> inspect.Signature: signature = inspect.signature(func) return signature + def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: signature = inspect_signature(func) From 226565a283bcfc2b63568ff825a8cb31eb74049b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 14:41:16 +0200 Subject: [PATCH 06/15] :art: Lint with mypy --- typer/main.py | 2 +- typer/pydantic_extension.py | 26 +++++++++++++------------- typer/utils.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/typer/main.py b/typer/main.py index f7e1173282..07c7658edf 100644 --- a/typer/main.py +++ b/typer/main.py @@ -12,7 +12,6 @@ from uuid import UUID import click -from typer.pydantic_extension import wrap_pydantic_callback from .completion import get_completion_inspect_parameters from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption @@ -35,6 +34,7 @@ Required, TyperInfo, ) +from .pydantic_extension import wrap_pydantic_callback from .utils import get_params_from_function try: diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index c9bff2047e..45ab721501 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -1,5 +1,5 @@ import inspect -from typing import Annotated, Callable +from typing import Annotated, Any, Callable import pydantic from pydantic._internal._utils import deep_update @@ -11,7 +11,9 @@ PYDANTIC_FIELD_SEPARATOR = "." -def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> dict: +def flatten_pydantic_model( + model: pydantic.BaseModel, ancestors: list[str] +) -> dict[str, inspect.Parameter]: from .main import lenient_issubclass pydantic_parameters = {} @@ -19,9 +21,8 @@ def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> d qualifier = [*ancestors, field_name] sub_name = f"_pydantic_{'_'.join(qualifier)}" if lenient_issubclass(field.annotation, pydantic.BaseModel): - pydantic_parameters.update( - flatten_pydantic_model(field.annotation, qualifier) - ) + params = flatten_pydantic_model(field.annotation, qualifier) # type: ignore[arg-type] + pydantic_parameters.update(params) else: default = field.default if field.default != PydanticUndefined else ... typer_option = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") @@ -34,7 +35,7 @@ def flatten_pydantic_model(model: pydantic.BaseModel, ancestors: list[str]) -> d return pydantic_parameters -def wrap_pydantic_callback(callback: Callable) -> Callable: +def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]: from .main import lenient_issubclass original_signature = inspect_signature(callback) @@ -44,9 +45,8 @@ def wrap_pydantic_callback(callback: Callable) -> Callable: other_parameters = {} for name, parameter in original_signature.parameters.items(): if lenient_issubclass(parameter.annotation, pydantic.BaseModel): - pydantic_parameters.update( - flatten_pydantic_model(parameter.annotation, [name]) - ) + params = flatten_pydantic_model(parameter.annotation, [name]) + pydantic_parameters.update(params) pydantic_roots[name] = parameter.annotation else: other_parameters[name] = parameter @@ -56,13 +56,13 @@ def wrap_pydantic_callback(callback: Callable) -> Callable: return_annotation=original_signature.return_annotation, ) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] converted_kwargs = kwargs.copy() - pydantic_dicts = {} + pydantic_dicts: dict[str, Any] = {} for kwarg_name, kwarg_value in kwargs.items(): if kwarg_name in pydantic_parameters: converted_kwargs.pop(kwarg_name) - annotation: Annotated = pydantic_parameters[kwarg_name].annotation + annotation = pydantic_parameters[kwarg_name].annotation _, qualifier = annotation.__metadata__ for part in reversed(qualifier): kwarg_value = {part: kwarg_value} @@ -71,5 +71,5 @@ def wrapper(*args, **kwargs): converted_kwargs[root_name] = pydantic_roots[root_name](**value) return callback(*args, **converted_kwargs) - wrapper.__signature__ = extended_signature + wrapper.__signature__ = extended_signature # type: ignore return wrapper diff --git a/typer/utils.py b/typer/utils.py index 232a094fcf..eaf3cebbf0 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -106,7 +106,7 @@ def _split_annotation_from_typer_annotations( ] -def inspect_signature(func: Callable) -> inspect.Signature: +def inspect_signature(func: Callable[..., Any]) -> inspect.Signature: if sys.version_info >= (3, 10): signature = inspect.signature(func, eval_str=True) else: From 05caa5a85f94814788b2c55c4d11d8a9249f739e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 20:59:53 +0200 Subject: [PATCH 07/15] :wrench: Add optional pydantic dependency, improve readability --- pyproject.toml | 1 + typer/main.py | 9 +------- typer/pydantic_extension.py | 38 ++++++++++++++++++-------------- typer/utils.py | 43 +++++++++++++++++++++++++++++++++++-- 4 files changed, 65 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c9e793e1ed..1f1255559c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ homepage = "https://github.com/tiangolo/typer" standard = [ "shellingham >=1.3.0", "rich >=10.11.0", + "pydantic >= 2.0.0", ] [tool.pdm] diff --git a/typer/main.py b/typer/main.py index 07c7658edf..9c35cfdc79 100644 --- a/typer/main.py +++ b/typer/main.py @@ -16,7 +16,6 @@ from .completion import get_completion_inspect_parameters from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption from .models import ( - AnyType, ArgumentInfo, CommandFunctionType, CommandInfo, @@ -35,7 +34,7 @@ TyperInfo, ) from .pydantic_extension import wrap_pydantic_callback -from .utils import get_params_from_function +from .utils import get_params_from_function, lenient_issubclass try: import rich @@ -790,12 +789,6 @@ def get_click_type( raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover -def lenient_issubclass( - cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]] -) -> bool: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) - - def get_click_param( param: ParamMeta, ) -> Tuple[Union[click.Argument, click.Option], Any]: diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index 45ab721501..903b34249d 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -1,30 +1,33 @@ import inspect -from typing import Annotated, Any, Callable - -import pydantic -from pydantic._internal._utils import deep_update -from pydantic_core import PydanticUndefined +from typing import Annotated, Any, Callable, Dict, List from .params import Option -from .utils import inspect_signature +from .utils import deep_update, inspect_signature, lenient_issubclass + +try: + import pydantic +except ImportError: + pydantic = None # type: ignore PYDANTIC_FIELD_SEPARATOR = "." def flatten_pydantic_model( - model: pydantic.BaseModel, ancestors: list[str] -) -> dict[str, inspect.Parameter]: - from .main import lenient_issubclass - + model: "pydantic.BaseModel", ancestors: List[str] +) -> Dict[str, inspect.Parameter]: + if pydantic is None: + raise ImportError("Pydantic is required to use Pydantic models with Typer.") pydantic_parameters = {} for field_name, field in model.model_fields.items(): qualifier = [*ancestors, field_name] sub_name = f"_pydantic_{'_'.join(qualifier)}" if lenient_issubclass(field.annotation, pydantic.BaseModel): - params = flatten_pydantic_model(field.annotation, qualifier) # type: ignore[arg-type] + params = flatten_pydantic_model(field.annotation, qualifier) # type: ignore pydantic_parameters.update(params) else: - default = field.default if field.default != PydanticUndefined else ... + default = ( + field.default if field.default is not pydantic.fields._Unset else ... + ) typer_option = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") pydantic_parameters[sub_name] = inspect.Parameter( sub_name, @@ -36,7 +39,8 @@ def flatten_pydantic_model( def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]: - from .main import lenient_issubclass + if pydantic is None: + return callback original_signature = inspect_signature(callback) @@ -58,7 +62,7 @@ def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]: def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] converted_kwargs = kwargs.copy() - pydantic_dicts: dict[str, Any] = {} + raw_pydantic_objects: Dict[str, Any] = {} for kwarg_name, kwarg_value in kwargs.items(): if kwarg_name in pydantic_parameters: converted_kwargs.pop(kwarg_name) @@ -66,8 +70,10 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] _, qualifier = annotation.__metadata__ for part in reversed(qualifier): kwarg_value = {part: kwarg_value} - pydantic_dicts = deep_update(pydantic_dicts, kwarg_value) - for root_name, value in pydantic_dicts.items(): + raw_pydantic_objects = deep_update( + raw_pydantic_objects, kwarg_value + ) + for root_name, value in raw_pydantic_objects.items(): converted_kwargs[root_name] = pydantic_roots[root_name](**value) return callback(*args, **converted_kwargs) diff --git a/typer/utils.py b/typer/utils.py index eaf3cebbf0..325ba588fc 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -1,12 +1,23 @@ import inspect import sys from copy import copy -from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, + Type, + TypeVar, + Union, + cast, + get_type_hints, +) from typing_extensions import Annotated from ._typing import get_args, get_origin -from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta +from .models import AnyType, ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str: @@ -106,6 +117,34 @@ def _split_annotation_from_typer_annotations( ] +def lenient_issubclass( + cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]] +) -> bool: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) + + +KeyType = TypeVar("KeyType") + + +def deep_update( + mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any] +) -> dict[KeyType, Any]: + # Copied from pydantic because they don't expose it publicly: + # https://github.com/pydantic/pydantic/blob/26129479a06960af9d02d3a948e51985fe59ed4b/pydantic/_internal/_utils.py#L103 + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if ( + k in updated_mapping + and isinstance(updated_mapping[k], dict) + and isinstance(v, dict) + ): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + def inspect_signature(func: Callable[..., Any]) -> inspect.Signature: if sys.version_info >= (3, 10): signature = inspect.signature(func, eval_str=True) From f4b7da8c27b5eb36f663e7c30786875c9fdcf811 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Apr 2024 19:01:27 +0000 Subject: [PATCH 08/15] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- typer/pydantic_extension.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index 903b34249d..e6f48c57a3 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -70,9 +70,7 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] _, qualifier = annotation.__metadata__ for part in reversed(qualifier): kwarg_value = {part: kwarg_value} - raw_pydantic_objects = deep_update( - raw_pydantic_objects, kwarg_value - ) + raw_pydantic_objects = deep_update(raw_pydantic_objects, kwarg_value) for root_name, value in raw_pydantic_objects.items(): converted_kwargs[root_name] = pydantic_roots[root_name](**value) return callback(*args, **converted_kwargs) From a6a7004b758a08816c9ece6021e1bdb87bf2fdae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 21:05:00 +0200 Subject: [PATCH 09/15] :art: Make linter happier --- typer/pydantic_extension.py | 4 +++- typer/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index e6f48c57a3..1af113b0c8 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -1,5 +1,7 @@ import inspect -from typing import Annotated, Any, Callable, Dict, List +from typing import Any, Callable, Dict, List + +from typing_extensions import Annotated from .params import Option from .utils import deep_update, inspect_signature, lenient_issubclass diff --git a/typer/utils.py b/typer/utils.py index 325ba588fc..f66bdd680b 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -127,8 +127,8 @@ def lenient_issubclass( def deep_update( - mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any] -) -> dict[KeyType, Any]: + mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any] +) -> Dict[KeyType, Any]: # Copied from pydantic because they don't expose it publicly: # https://github.com/pydantic/pydantic/blob/26129479a06960af9d02d3a948e51985fe59ed4b/pydantic/_internal/_utils.py#L103 updated_mapping = mapping.copy() From 24a6215365af44262e6a08e83333ecbf3bb26784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 21:45:48 +0200 Subject: [PATCH 10/15] :bug: Fix forward refs for Python <= 3.9 --- typer/pydantic_extension.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index 1af113b0c8..5573e299a7 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -78,4 +78,6 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] return callback(*args, **converted_kwargs) wrapper.__signature__ = extended_signature # type: ignore + # Copy annotations to make forward references work in Python <= 3.9 + wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()} return wrapper From 3e37bb5e1152a3497cd9a61c6a919dd8660ce250 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Apr 2024 19:46:27 +0000 Subject: [PATCH 11/15] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- typer/pydantic_extension.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index 5573e299a7..b9bba8aa2d 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -79,5 +79,7 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] wrapper.__signature__ = extended_signature # type: ignore # Copy annotations to make forward references work in Python <= 3.9 - wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()} + wrapper.__annotations__ = { + k: v.annotation for k, v in extended_signature.parameters.items() + } return wrapper From 05fce89dadde60de4a0eed161cb65d7dda575254 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Wed, 24 Apr 2024 22:41:11 +0200 Subject: [PATCH 12/15] :white_check_mark: Tutorial and test nested model --- .../parameter_types/pydantic/tutorial002.py | 24 +++++++++++++ .../test_pydantic/test_tutorial002.py | 34 +++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 docs_src/parameter_types/pydantic/tutorial002.py create mode 100644 tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py diff --git a/docs_src/parameter_types/pydantic/tutorial002.py b/docs_src/parameter_types/pydantic/tutorial002.py new file mode 100644 index 0000000000..cd8b75c65a --- /dev/null +++ b/docs_src/parameter_types/pydantic/tutorial002.py @@ -0,0 +1,24 @@ +from typing import List, Optional + +import typer + +import pydantic + + +class Pet(pydantic.BaseModel): + name: str + species: str + + +class Person(pydantic.BaseModel): + name: str + age: Optional[float] = None + pet: Pet + + +def main(person: Person): + print(person, type(person)) + + +if __name__ == "__main__": + typer.run(main) diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py new file mode 100644 index 0000000000..fedb8aa6f9 --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py @@ -0,0 +1,34 @@ +import subprocess +import sys + +import typer +from typer.testing import CliRunner + +from docs_src.parameter_types.pydantic import tutorial002 as mod + +runner = CliRunner() + +app = typer.Typer() +app.command()(mod.main) + + +def test_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + + +def test_parse_pydantic_model(): + result = runner.invoke(app, ["--person.name", "Jeff", "--person.pet.name", "Lassie", "--person.pet.species", "dog"]) + assert ( + "name='Jeff' age=None pet=Pet(name='Lassie', species='dog') " + in result.output + ) + + +def test_script(): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout From 6c76ab9437ee91085ab47b4d1bdd8447ccc712b4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:41:38 +0000 Subject: [PATCH 13/15] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs_src/parameter_types/pydantic/tutorial002.py | 2 +- .../test_pydantic/test_tutorial002.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs_src/parameter_types/pydantic/tutorial002.py b/docs_src/parameter_types/pydantic/tutorial002.py index cd8b75c65a..1bdf52cfdb 100644 --- a/docs_src/parameter_types/pydantic/tutorial002.py +++ b/docs_src/parameter_types/pydantic/tutorial002.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import typer diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py index fedb8aa6f9..a3132aee35 100644 --- a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py @@ -18,7 +18,17 @@ def test_help(): def test_parse_pydantic_model(): - result = runner.invoke(app, ["--person.name", "Jeff", "--person.pet.name", "Lassie", "--person.pet.species", "dog"]) + result = runner.invoke( + app, + [ + "--person.name", + "Jeff", + "--person.pet.name", + "Lassie", + "--person.pet.species", + "dog", + ], + ) assert ( "name='Jeff' age=None pet=Pet(name='Lassie', species='dog') " in result.output From ebb58772d751036ae79ed5e86e1389835e4c9e95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Thu, 25 Apr 2024 14:47:31 +0200 Subject: [PATCH 14/15] :white_check_mark: Add tests for missing pydantic --- .../test_pydantic/test_tutorial001.py | 16 ++++++++++++++++ typer/pydantic_extension.py | 10 +++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py index 1d0d81a60b..d4b7a20547 100644 --- a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py @@ -1,6 +1,7 @@ import subprocess import sys +import pytest import typer from typer.testing import CliRunner @@ -33,3 +34,18 @@ def test_script(): encoding="utf-8", ) assert "Usage" in result.stdout + + +def test_error_without_pydantic(): + pydantic = typer.pydantic_extension.pydantic + typer.pydantic_extension.pydantic = None + with pytest.raises( + RuntimeError, + match="Type not yet supported: ", + ): + runner.invoke( + app, + ["1", "--user.id", "2", "--user.name", "John Doe"], + catch_exceptions=False, + ) + typer.pydantic_extension.pydantic = pydantic diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index b9bba8aa2d..b177a054b7 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -14,17 +14,17 @@ PYDANTIC_FIELD_SEPARATOR = "." -def flatten_pydantic_model( +def _flatten_pydantic_model( model: "pydantic.BaseModel", ancestors: List[str] ) -> Dict[str, inspect.Parameter]: - if pydantic is None: - raise ImportError("Pydantic is required to use Pydantic models with Typer.") + # This function should only be called if pydantic is available + assert pydantic is not None pydantic_parameters = {} for field_name, field in model.model_fields.items(): qualifier = [*ancestors, field_name] sub_name = f"_pydantic_{'_'.join(qualifier)}" if lenient_issubclass(field.annotation, pydantic.BaseModel): - params = flatten_pydantic_model(field.annotation, qualifier) # type: ignore + params = _flatten_pydantic_model(field.annotation, qualifier) # type: ignore pydantic_parameters.update(params) else: default = ( @@ -51,7 +51,7 @@ def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]: other_parameters = {} for name, parameter in original_signature.parameters.items(): if lenient_issubclass(parameter.annotation, pydantic.BaseModel): - params = flatten_pydantic_model(parameter.annotation, [name]) + params = _flatten_pydantic_model(parameter.annotation, [name]) pydantic_parameters.update(params) pydantic_roots[name] = parameter.annotation else: From 12be77ecd65946027ff5caba9cc66169eda73310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20D=C3=BCggelin?= Date: Thu, 25 Apr 2024 15:18:22 +0200 Subject: [PATCH 15/15] :wrench: Don't report coverage for conditional import --- typer/pydantic_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py index b177a054b7..f2f6a4f6a4 100644 --- a/typer/pydantic_extension.py +++ b/typer/pydantic_extension.py @@ -8,7 +8,7 @@ try: import pydantic -except ImportError: +except ImportError: # pragma: no cover pydantic = None # type: ignore PYDANTIC_FIELD_SEPARATOR = "."