From fa190decf7248028e3a265f44195f054370325ca Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <1368497+nikalra@users.noreply.github.com>
Date: Fri, 10 Mar 2023 12:44:10 -0800
Subject: [PATCH] Support torch.amax and torch.amin (#1797)

Support torch.amax and torch.amin
---
 .../converters/mil/frontend/torch/ops.py      | 20 +++++++++++
 .../mil/frontend/torch/test/test_torch_ops.py | 35 +++++++++++++++++++
 2 files changed, 55 insertions(+)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 640c21030..370f4401a 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -4307,6 +4307,26 @@ def max(context, node):
     context.add(values, torch_name=values_name)
     context.add(indices, torch_name=indices_name)
 
+def _add_amax_amin(context, node, reduce_op):
+    # mimic functionality from https://pytorch.org/docs/stable/generated/torch.amax.html
+    # mimic functionality from https://pytorch.org/docs/stable/generated/torch.amin.html
+    assert len(node.outputs) == 1
+
+    all_inputs = _get_inputs(context, node, expected=[2, 3])
+    _input = all_inputs[0]
+    dim = [all_inputs[1].val] if type(all_inputs[1].val) == int else [x for x in all_inputs[1].val]
+    keepdim = all_inputs[2] if len(all_inputs) == 3 else False
+
+    context.add(reduce_op(x=_input, axes=dim, keep_dims=keepdim), torch_name=node.outputs[0])
+
+@register_torch_op
+def amax(context, node):
+    _add_amax_amin(context, node, mb.reduce_max)
+
+@register_torch_op
+def amin(context, node):
+    _add_amax_amin(context, node, mb.reduce_min)
+
 
 @register_torch_op
 def argsort(context, node):
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index c23f31cc2..99bf5fff9 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -2865,6 +2865,41 @@ def forward(self, x, y):
             input_shapes, model, backend=backend, compute_unit=compute_unit
         )
 
+class TestAMaxAMin(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, input_shapes, mode, reduce_dim, keepdim",
+        itertools.product(
+            compute_units,
+            backends,
+            [
+                [(2, 5, 7, 3)],
+                [(3, 2, 9)],
+                [(1,)],
+            ],
+            ["minimum", "maximum"],
+            [0, 1, 2, 3, [0, 1], [0, 1, 2], [0, 1, 2, 3]],
+            [True, False],
+        ),
+    )
+    def test_minimum_maximum(self, compute_unit, backend, input_shapes, mode, reduce_dim, keepdim):
+        class TestModel(torch.nn.Module):
+            def forward(self, input):
+                if type(reduce_dim) == int:
+                    reduce_dim_clamped = min(input.dim() - 1, reduce_dim)
+                else:
+                    reduce_dim_clamped = reduce_dim[:input.dim()]
+                if mode == "minimum":
+                    return torch.amin(input, reduce_dim_clamped, keepdim)
+                elif mode == "maximum":
+                    return torch.amax(input, reduce_dim_clamped, keepdim)
+                else:
+                    raise ValueError("Unsupported mode: {mode}".format(mode=mode))
+
+        model = TestModel()
+        self.run_compare_torch(
+            input_shapes, model, backend=backend, compute_unit=compute_unit
+        )
+
 
 class TestPoolSymbolicInput(TorchBaseTest):
     def test_max_pool(self):
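---
Example usage (illustrative, not part of the patch): a minimal sketch of
converting a traced model that exercises the new amax/amin handlers,
assuming the standard coremltools ct.convert flow. The model, input shape,
and names below are assumptions for demonstration only.

    import torch
    import coremltools as ct

    class SpreadModel(torch.nn.Module):
        # Hypothetical model: per-channel peak-to-peak range via amax/amin.
        def forward(self, x):
            # torch.amax/torch.amin return reduced values only, unlike
            # torch.max/torch.min, which also return indices.
            hi = torch.amax(x, dim=(2, 3), keepdim=True)
            lo = torch.amin(x, dim=(2, 3), keepdim=True)
            return hi - lo

    example = torch.rand(1, 3, 8, 8)
    traced = torch.jit.trace(SpreadModel().eval(), example)

    # With the @register_torch_op handlers above, amax/amin lower to MIL's
    # reduce_max/reduce_min; without them, conversion fails on these ops.
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="x", shape=example.shape)],
    )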