
Incompatibility between clip and torch.vmap #350


Description

@TimothyEDawson

I stumbled into an edge case when trying to apply torch.vmap to some code I had rewritten to utilize array-api-compat. So far everything seems to work just fine, with the exception of clip. Here's a minimal example:

import array_api_compat.torch as xp
import torch


def apply_clip(a):
    # Plain PyTorch clip.
    return torch.clip(a, min=0, max=30)


def apply_clip_compat(a):
    # Same operation through the array-api-compat wrapper.
    return xp.clip(a, min=0, max=30)


a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])

# Unbatched: both versions work.
print(apply_clip(a))
print(apply_clip_compat(a))

# vmap over the plain PyTorch version: works.
v1 = torch.vmap(apply_clip)
print(v1(a))

# vmap over the array-api-compat version: raises.
v2 = xp.vmap(apply_clip_compat)
print(v2(a))

Running this, the first three calls succeed, but the last one raises an error:

[user@domain ~]$ python test_clip.py 
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
Traceback (most recent call last):
  File "test_clip.py", line 22, in <module>
    print(v2(a))
          ~~^^^
  File ".venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "test_clip.py", line 10, in apply_clip_compat
    return xp.clip(a, min=0, max=30)
           ~~~~~~~^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/array_api_compat/_internal.py", line 35, in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ".venv/lib/python3.13/site-packages/array_api_compat/common/_aliases.py", line 424, in clip
    out[()] = x
    ~~~^^^^
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
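
The traceback points at the in-place assignment out[()] = x in array_api_compat/common/_aliases.py: clip allocates a fresh output buffer and copies the input into it, and under vmap that copy writes a batched tensor into an unbatched one, which functorch forbids. The same failure should be reproducible without array-api-compat; here is a minimal standalone sketch of the same pattern (my own code, not the library's):

import torch


def assign_out(x):
    # Mirrors the pattern in array-api-compat's clip: allocate a fresh
    # (unbatched) output buffer, then copy the input into it in place.
    out = torch.empty(x.shape, dtype=x.dtype)
    out[()] = x  # in-place write; under vmap, x is batched but out is not
    return out


try:
    torch.vmap(assign_out)(torch.tensor([5.1, 2.0, 64.1, -1.5]))
except RuntimeError as e:
    print(e)  # the same "vmap: inplace arithmetic ..." error as above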

I totally understand if full support for torch.vmap is out of scope, but I figured it was worth raising in case there's something that needs fixing.
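
For anyone hitting this in the meantime, one out-of-place workaround that appears to work is building the clamp from where() calls, which vmap can batch. A sketch (the names apply_clip_where and v3 are mine, and this is not array-api-compat's actual implementation):

import array_api_compat.torch as xp
import torch


def apply_clip_where(a):
    # Out-of-place clip built from where(); no in-place writes, so vmap
    # can batch it. Bounds are cast to the input dtype for broadcasting.
    lo = xp.asarray(0.0, dtype=a.dtype)
    hi = xp.asarray(30.0, dtype=a.dtype)
    return xp.where(a < lo, lo, xp.where(a > hi, hi, a))


v3 = torch.vmap(apply_clip_where)
print(v3(xp.asarray([[5.1, 2.0, 64.1, -1.5]])))
# tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])

Since where() only uses standard array API calls, this stays portable across backends rather than special-casing torch.clamp.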
