Incompatibility between clip and torch.vmap
I stumbled into an edge case when trying to apply torch.vmap to some code I had rewritten to utilize array-api-compat. So far everything seems to work just fine, with the exception of clip. Here's a minimal example:
import array_api_compat.torch as xp
import torch


def apply_clip(a):
    return torch.clip(a, min=0, max=30)


def apply_clip_compat(a):
    return xp.clip(a, min=0, max=30)


a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])

print(apply_clip(a))
print(apply_clip_compat(a))

v1 = torch.vmap(apply_clip)
print(v1(a))

v2 = xp.vmap(apply_clip_compat)
print(v2(a))
Running this raises the following error:
[user@domain ~]$ python test_clip.py
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
Traceback (most recent call last):
  File "test_clip.py", line 22, in <module>
    print(v2(a))
    ~~^^^
  File ".venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
        ...<6 lines>...
        **kwargs,
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "test_clip.py", line 10, in apply_clip_compat
    return xp.clip(a, min=0, max=30)
           ~~~~~~~^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/array_api_compat/_internal.py", line 35, in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ".venv/lib/python3.13/site-packages/array_api_compat/common/_aliases.py", line 424, in clip
    out[()] = x
    ~~~^^^^
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
I totally understand if full support for torch.vmap is out of scope, but I figured it was worth raising in case there's something here that needs fixing.
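For what it's worth, following the error message's suggestion to use out-of-place operators does seem to sidestep the problem on my side. A minimal sketch, assuming xp.where here simply maps through to torch.where, and reusing the bounds from the example above:

def apply_clip_workaround(a):
    # Build the clipped result out-of-place instead of writing into a preallocated tensor.
    a = xp.where(a < 0, xp.asarray(0.0), a)
    a = xp.where(a > 30, xp.asarray(30.0), a)
    return a

v3 = xp.vmap(apply_clip_workaround)
print(v3(a))  # same result as the non-vmapped calls above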
that is a shame! https://github.com/data-apis/array-api-compat/blob/be9eff7e382997608d3fa8c87fd559c0810dd366/array_api_compat/common/_aliases.py#L438-L439
For most things, [()] works to extract a NumPy scalar for NumPy and is a no-op for other arrays. But with vmap the no-op assumption apparently breaks down.
In the long run, I think this calls for an as_scalar_if_numpy helper to deal with the pesky NumPy special case in a safe way. Not sure if there is a simpler workaround in the short term. WDYT @ev-br ?
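Roughly what I have in mind, as a sketch only (the helper name is the one floated above, is_numpy_array is the helper I believe array-api-compat already exposes, and restricting the special case to 0-d arrays is my assumption):

from array_api_compat import is_numpy_array
import numpy as np
import torch

def as_scalar_if_numpy(x):
    # 0-d NumPy arrays become NumPy scalars; every other array passes through
    # untouched, so no indexing (and no hidden copy or write) happens for torch & co.
    if is_numpy_array(x) and x.ndim == 0:
        return x[()]
    return x

as_scalar_if_numpy(np.asarray(3.0))     # np.float64(3.0) -- a NumPy scalar
as_scalar_if_numpy(torch.asarray(3.0))  # tensor(3.) -- unchanged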
In array-api-compat specifically, I think it's a matter of a simple refactoring to move the [()] out of common/ and into numpy/ and the other backend-specific folders.
EDIT: one other offender is common/sign.
Good shout!
It would be great to better understand the origin of this error.
If I replace out[()] = x with out = torch.clone(x) at https://github.com/data-apis/array-api-compat/blob/be9eff7e382997608d3fa8c87fd559c0810dd366/array_api_compat/common/_aliases.py#L424, it immediately fails a few lines below at https://github.com/data-apis/array-api-compat/blob/be9eff7e382997608d3fa8c87fd559c0810dd366/array_api_compat/common/_aliases.py#L430
While all of this can be worked around, it raises the question of what the deeper reason is. Surely it cannot be that indexing operations are incompatible with vmap?
Ah, OK, this seems to be the answer: https://discuss.pytorch.org/t/vmap-inplace-arithmetic-error/196556/2
"you can’t have a write to an existing tensor"
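A stripped-down repro, independent of array-api-compat, is consistent with that rule (my own sketch, nothing here is from the library):

import torch

def writes_into_new_tensor(a):
    out = torch.empty(a.shape, dtype=a.dtype)  # a plain tensor, not vmapped over
    out[()] = a                                # write a vmapped tensor into it
    return out

def returns_a_copy(a):
    return torch.clone(a)                      # out-of-place: no write into an existing tensor

x = torch.asarray([5.1, 2.0, 64.1, -1.5])
print(torch.vmap(returns_a_copy)(x))     # fine
# torch.vmap(writes_into_new_tensor)(x)  # raises a RuntimeError like the one above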
So indeed, mutating indexing operations are incompatible with torch.vmap, which means array_api_compat.torch.clip should be rewritten in terms of torch.clamp instead of emulating clip with in-place assignments as it does at the moment.
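Something along these lines, as a sketch of the idea only (not necessarily what the PR below does; it deliberately ignores the Array API's promotion rules for Python-scalar bounds, and returning a copy in the no-bounds branch is just to stay out-of-place):

import torch

def clip(x, /, min=None, max=None):
    # Fully out-of-place, so it composes with torch.vmap:
    # torch.clamp never writes into an existing tensor.
    if min is None and max is None:
        return torch.clone(x)
    return torch.clamp(x, min=min, max=max)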
A tentative fix is in https://github.com/data-apis/array-api-compat/pull/353, but I'm not sure it's entirely correct because of promotion rules. Thoughts?