diff --git a/hug/interface.py b/hug/interface.py index d3fc31d9..0b56c340 100644 --- a/hug/interface.py +++ b/hug/interface.py @@ -212,6 +212,7 @@ def __init__(self, route, function): for name, transform in self.interface.input_transformations.items() } else: + self.map_params = {} self.input_transformations = self.interface.input_transformations if "output" in route: @@ -323,7 +324,7 @@ def documentation(self, add_to=None): inputs = doc.setdefault("inputs", OrderedDict()) types = self.interface.spec.__annotations__ for argument in parameters: - kind = types.get(argument, text) + kind = types.get(self._remap_entry(argument), text) if getattr(kind, "directive", None) is True: continue @@ -340,6 +341,9 @@ def _rewrite_params(self, params): if interface_name in params: params[internal_name] = params.pop(interface_name) + def _remap_entry(self, interface_name): + return self.map_params.get(interface_name, interface_name) + @staticmethod def cleanup_parameters(parameters, exception=None): for _parameter, directive in parameters.items(): @@ -417,8 +421,7 @@ def __call__(self, *args, **kwargs): self.api.delete_context(context, errors=errors) return outputs(errors) if outputs else errors - if getattr(self, "map_params", None): - self._rewrite_params(kwargs) + self._rewrite_params(kwargs) try: result = self.interface(**kwargs) if self.transform: @@ -617,8 +620,7 @@ def exit_callback(message): elif add_options_to: pass_to_function[add_options_to].append(option) - if getattr(self, "map_params", None): - self._rewrite_params(pass_to_function) + self._rewrite_params(pass_to_function) try: if args: @@ -816,8 +818,7 @@ def call_function(self, parameters): parameters = { key: value for key, value in parameters.items() if key in self.all_parameters } - if getattr(self, "map_params", None): - self._rewrite_params(parameters) + self._rewrite_params(parameters) return self.interface(**parameters) diff --git a/tests/test_documentation.py b/tests/test_documentation.py index 0655f1ed..b528dee6 100644 --- a/tests/test_documentation.py +++ b/tests/test_documentation.py @@ -185,3 +185,11 @@ def marshtest() -> Returns(): doc = api.http.documentation() assert doc["handlers"]["/marshtest"]["POST"]["outputs"]["type"] == "Return docs" + +def test_map_params_documentation_preserves_type(): + @hug.get(map_params={"from": "from_mapped"}) + def map_params_test(from_mapped: hug.types.number): + pass + + doc = api.http.documentation() + assert doc["handlers"]["/map_params_test"]["GET"]["inputs"]["from"]["type"] == "A whole number" diff --git a/tox.ini b/tox.ini index abf21c4f..9428483f 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,7 @@ deps= marshmallow3: marshmallow==3.0.0rc6 whitelist_externals=flake8 -commands=py.test --cov-report html --cov hug -n auto tests +commands=py.test --durations 3 --cov-report html --cov hug -n auto tests [testenv:py37-black] deps=