Status: Open
Labels: bug
Description
I stumbled into an edge case when trying to apply `torch.vmap` to some code I had rewritten to use array-api-compat. So far everything seems to work fine, with the exception of `clip`. Here's a minimal example:
```python
import array_api_compat.torch as xp
import torch


def apply_clip(a):
    return torch.clip(a, min=0, max=30)


def apply_clip_compat(a):
    return xp.clip(a, min=0, max=30)


a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])
print(apply_clip(a))         # plain torch, eager: works
print(apply_clip_compat(a))  # array-api-compat, eager: works

v1 = torch.vmap(apply_clip)
print(v1(a))                 # plain torch under vmap: works

v2 = xp.vmap(apply_clip_compat)
print(v2(a))                 # array-api-compat under vmap: raises
```
The first three calls print the expected result, but the vmapped compat version raises an error:
```
[user@domain ~]$ python test_clip.py
tensor([[ 5.1000, 2.0000, 30.0000, 0.0000]])
tensor([[ 5.1000, 2.0000, 30.0000, 0.0000]])
tensor([[ 5.1000, 2.0000, 30.0000, 0.0000]])
Traceback (most recent call last):
  File "test_clip.py", line 22, in <module>
    print(v2(a))
          ~~^^^
  File ".venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "test_clip.py", line 10, in apply_clip_compat
    return xp.clip(a, min=0, max=30)
           ~~~~~~~^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/array_api_compat/_internal.py", line 35, in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ".venv/lib/python3.13/site-packages/array_api_compat/common/_aliases.py", line 424, in clip
    out[()] = x
    ~~~^^^^
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
```
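The last frame points at `out[()] = x` inside array-api-compat's `clip`, and the message says the tensor being written into (`self`) is not batched while the tensor being copied in (`other`) is. The sketch below reproduces that allocate-then-write-in-place pattern in isolation, assuming, as the error message suggests, that the output buffer is a plain tensor allocated inside the function. `clip_via_inplace_write` is a hypothetical helper for illustration, not array-api-compat's actual code.

```python
import torch


def clip_via_inplace_write(x, min=None, max=None):
    """Hypothetical helper mimicking an allocate-then-write-in-place clip.

    Assumption: this only approximates the pattern at the failing line
    (`out[()] = x`); it is not array-api-compat's implementation.
    """
    # Inside vmap, x.shape is the per-example shape, and torch.empty
    # returns a plain (non-batched) tensor.
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    # In-place copy of a batched tensor into a non-batched one:
    # vmap rejects this with a RuntimeError.
    out[()] = x
    if min is not None:
        out.clamp_(min=min)
    if max is not None:
        out.clamp_(max=max)
    return out


a = torch.tensor([[5.1, 2.0, 64.1, -1.5]])

# Eagerly (outside vmap) the pattern is fine.
print(clip_via_inplace_write(a, min=0, max=30))

# Under vmap it fails at the in-place copy.
try:
    torch.vmap(lambda t: clip_via_inplace_write(t, min=0, max=30))(a)
except RuntimeError as err:
    print("vmap failed:", err)
```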
I totally understand if full support for `torch.vmap` is out of scope, but I figured it was worth raising in case there's something here that needs fixing.
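One possible workaround, following the error message's advice to use out-of-place operators, is to build the clipped result without writing into a preallocated buffer. `clip_out_of_place` below is a hypothetical helper, not part of array-api-compat; it assumes scalar `min`/`max` bounds and uses `torch.where`, so NaNs pass through unchanged because comparisons with NaN are false.

```python
import torch


def clip_out_of_place(x, min=None, max=None):
    """Hypothetical vmap-friendly clip for scalar bounds.

    Every step returns a new tensor, so nothing is copied in place into a
    non-batched buffer. NaNs survive because NaN comparisons are False.
    """
    out = x
    if min is not None:
        lo = torch.as_tensor(min, dtype=x.dtype, device=x.device)
        out = torch.where(out < lo, lo, out)
    if max is not None:
        hi = torch.as_tensor(max, dtype=x.dtype, device=x.device)
        out = torch.where(out > hi, hi, out)
    return out


def apply_clip_workaround(a):
    return clip_out_of_place(a, min=0, max=30)


a = torch.tensor([[5.1, 2.0, 64.1, -1.5]])
print(torch.vmap(apply_clip_workaround)(a))
```

Calling `torch.clip` directly inside the vmapped function also sidesteps the problem, as the third line of the output above shows; the `torch.where` formulation is just one out-of-place option.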