diff --git a/docs_src/parameter_types/pydantic/__init__.py b/docs_src/parameter_types/pydantic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs_src/parameter_types/pydantic/tutorial001.py b/docs_src/parameter_types/pydantic/tutorial001.py new file mode 100644 index 0000000000..2265ec987c --- /dev/null +++ b/docs_src/parameter_types/pydantic/tutorial001.py @@ -0,0 +1,17 @@ +import typer + +import pydantic + + +class User(pydantic.BaseModel): + id: int + name: str = "Jane Doe" + + +def main(num: int, user: User): + print(num, type(num)) + print(user, type(user)) + + +if __name__ == "__main__": + typer.run(main) diff --git a/docs_src/parameter_types/pydantic/tutorial002.py b/docs_src/parameter_types/pydantic/tutorial002.py new file mode 100644 index 0000000000..1bdf52cfdb --- /dev/null +++ b/docs_src/parameter_types/pydantic/tutorial002.py @@ -0,0 +1,24 @@ +from typing import Optional + +import typer + +import pydantic + + +class Pet(pydantic.BaseModel): + name: str + species: str + + +class Person(pydantic.BaseModel): + name: str + age: Optional[float] = None + pet: Pet + + +def main(person: Person): + print(person, type(person)) + + +if __name__ == "__main__": + typer.run(main) diff --git a/pyproject.toml b/pyproject.toml index c9e793e1ed..1f1255559c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ homepage = "https://github.com/tiangolo/typer" standard = [ "shellingham >=1.3.0", "rich >=10.11.0", + "pydantic >= 2.0.0", ] [tool.pdm] diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/__init__.py b/tests/test_tutorial/test_parameter_types/test_pydantic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py new file mode 100644 index 0000000000..d4b7a20547 --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial001.py @@ -0,0 +1,51 @@ +import subprocess +import sys + +import pytest +import typer +from typer.testing import CliRunner + +from docs_src.parameter_types.pydantic import tutorial001 as mod + +runner = CliRunner() + +app = typer.Typer() +app.command()(mod.main) + + +def test_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + + +def test_parse_pydantic_model(): + result = runner.invoke(app, ["1", "--user.id", "2", "--user.name", "John Doe"]) + assert "1 " in result.output + assert ( + "id=2 name='John Doe' " + in result.output + ) + + +def test_script(): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout + + +def test_error_without_pydantic(): + pydantic = typer.pydantic_extension.pydantic + typer.pydantic_extension.pydantic = None + with pytest.raises( + RuntimeError, + match="Type not yet supported: ", + ): + runner.invoke( + app, + ["1", "--user.id", "2", "--user.name", "John Doe"], + catch_exceptions=False, + ) + typer.pydantic_extension.pydantic = pydantic diff --git a/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py new file mode 100644 index 0000000000..a3132aee35 --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_pydantic/test_tutorial002.py @@ -0,0 +1,44 @@ +import subprocess +import sys + +import typer +from typer.testing import CliRunner + +from docs_src.parameter_types.pydantic import tutorial002 as mod + +runner = CliRunner() + +app = typer.Typer() +app.command()(mod.main) + + +def test_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + + +def test_parse_pydantic_model(): + result = runner.invoke( + app, + [ + "--person.name", + "Jeff", + "--person.pet.name", + "Lassie", + "--person.pet.species", + "dog", + ], + ) + assert ( + "name='Jeff' age=None pet=Pet(name='Lassie', species='dog') " + in result.output + ) + + +def test_script(): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout diff --git a/typer/main.py b/typer/main.py index 9db26975ca..9c35cfdc79 100644 --- a/typer/main.py +++ b/typer/main.py @@ -16,7 +16,6 @@ from .completion import get_completion_inspect_parameters from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption from .models import ( - AnyType, ArgumentInfo, CommandFunctionType, CommandInfo, @@ -34,7 +33,8 @@ Required, TyperInfo, ) -from .utils import get_params_from_function +from .pydantic_extension import wrap_pydantic_callback +from .utils import get_params_from_function, lenient_issubclass try: import rich @@ -572,17 +572,18 @@ def get_command_from_info( use_help = inspect.getdoc(command_info.callback) else: use_help = inspect.cleandoc(use_help) + callback = wrap_pydantic_callback(command_info.callback) ( params, convertors, context_param_name, - ) = get_params_convertors_ctx_param_name_from_function(command_info.callback) + ) = get_params_convertors_ctx_param_name_from_function(callback) cls = command_info.cls or TyperCommand command = cls( name=name, context_settings=command_info.context_settings, callback=get_callback( - callback=command_info.callback, + callback=callback, params=params, convertors=convertors, context_param_name=context_param_name, @@ -788,12 +789,6 @@ def get_click_type( raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover -def lenient_issubclass( - cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]] -) -> bool: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) - - def get_click_param( param: ParamMeta, ) -> Tuple[Union[click.Argument, click.Option], Any]: diff --git a/typer/pydantic_extension.py b/typer/pydantic_extension.py new file mode 100644 index 0000000000..f2f6a4f6a4 --- /dev/null +++ b/typer/pydantic_extension.py @@ -0,0 +1,85 @@ +import inspect +from typing import Any, Callable, Dict, List + +from typing_extensions import Annotated + +from .params import Option +from .utils import deep_update, inspect_signature, lenient_issubclass + +try: + import pydantic +except ImportError: # pragma: no cover + pydantic = None # type: ignore + +PYDANTIC_FIELD_SEPARATOR = "." + + +def _flatten_pydantic_model( + model: "pydantic.BaseModel", ancestors: List[str] +) -> Dict[str, inspect.Parameter]: + # This function should only be called if pydantic is available + assert pydantic is not None + pydantic_parameters = {} + for field_name, field in model.model_fields.items(): + qualifier = [*ancestors, field_name] + sub_name = f"_pydantic_{'_'.join(qualifier)}" + if lenient_issubclass(field.annotation, pydantic.BaseModel): + params = _flatten_pydantic_model(field.annotation, qualifier) # type: ignore + pydantic_parameters.update(params) + else: + default = ( + field.default if field.default is not pydantic.fields._Unset else ... + ) + typer_option = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}") + pydantic_parameters[sub_name] = inspect.Parameter( + sub_name, + inspect.Parameter.KEYWORD_ONLY, + annotation=Annotated[field.annotation, typer_option, qualifier], + default=default, + ) + return pydantic_parameters + + +def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]: + if pydantic is None: + return callback + + original_signature = inspect_signature(callback) + + pydantic_parameters = {} + pydantic_roots = {} + other_parameters = {} + for name, parameter in original_signature.parameters.items(): + if lenient_issubclass(parameter.annotation, pydantic.BaseModel): + params = _flatten_pydantic_model(parameter.annotation, [name]) + pydantic_parameters.update(params) + pydantic_roots[name] = parameter.annotation + else: + other_parameters[name] = parameter + + extended_signature = inspect.Signature( + [*other_parameters.values(), *pydantic_parameters.values()], + return_annotation=original_signature.return_annotation, + ) + + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] + converted_kwargs = kwargs.copy() + raw_pydantic_objects: Dict[str, Any] = {} + for kwarg_name, kwarg_value in kwargs.items(): + if kwarg_name in pydantic_parameters: + converted_kwargs.pop(kwarg_name) + annotation = pydantic_parameters[kwarg_name].annotation + _, qualifier = annotation.__metadata__ + for part in reversed(qualifier): + kwarg_value = {part: kwarg_value} + raw_pydantic_objects = deep_update(raw_pydantic_objects, kwarg_value) + for root_name, value in raw_pydantic_objects.items(): + converted_kwargs[root_name] = pydantic_roots[root_name](**value) + return callback(*args, **converted_kwargs) + + wrapper.__signature__ = extended_signature # type: ignore + # Copy annotations to make forward references work in Python <= 3.9 + wrapper.__annotations__ = { + k: v.annotation for k, v in extended_signature.parameters.items() + } + return wrapper diff --git a/typer/utils.py b/typer/utils.py index 2ba7bace45..f66bdd680b 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -1,12 +1,23 @@ import inspect import sys from copy import copy -from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints +from typing import ( + Any, + Callable, + Dict, + List, + Tuple, + Type, + TypeVar, + Union, + cast, + get_type_hints, +) from typing_extensions import Annotated from ._typing import get_args, get_origin -from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta +from .models import AnyType, ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str: @@ -106,11 +117,44 @@ def _split_annotation_from_typer_annotations( ] -def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: +def lenient_issubclass( + cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]] +) -> bool: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) + + +KeyType = TypeVar("KeyType") + + +def deep_update( + mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any] +) -> Dict[KeyType, Any]: + # Copied from pydantic because they don't expose it publicly: + # https://github.com/pydantic/pydantic/blob/26129479a06960af9d02d3a948e51985fe59ed4b/pydantic/_internal/_utils.py#L103 + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if ( + k in updated_mapping + and isinstance(updated_mapping[k], dict) + and isinstance(v, dict) + ): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + +def inspect_signature(func: Callable[..., Any]) -> inspect.Signature: if sys.version_info >= (3, 10): signature = inspect.signature(func, eval_str=True) else: signature = inspect.signature(func) + return signature + + +def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: + signature = inspect_signature(func) type_hints = get_type_hints(func) params = {}