
Incompatibility between clip and torch.vmap

Open TimothyEDawson opened this issue 4 months ago • 6 comments

I stumbled into an edge case while applying torch.vmap to some code I had rewritten to use array-api-compat. So far everything seems to work just fine, with the exception of clip. Here's a minimal example:

import array_api_compat.torch as xp
import torch


def apply_clip(a):
    return torch.clip(a, min=0, max=30)


def apply_clip_compat(a):
    return xp.clip(a, min=0, max=30)


a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])

print(apply_clip(a))
print(apply_clip_compat(a))

v1 = torch.vmap(apply_clip)
print(v1(a))

v2 = xp.vmap(apply_clip_compat)
print(v2(a))

Which raises the following error:

[user@domain ~]$ python test_clip.py 
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])
Traceback (most recent call last):
  File "test_clip.py", line 22, in <module>
    print(v2(a))
          ~~^^^
  File ".venv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
        func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
    return _flat_vmap(
        func,
    ...<6 lines>...
        **kwargs,
    )
  File ".venv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "test_clip.py", line 10, in apply_clip_compat
    return xp.clip(a, min=0, max=30)
           ~~~~~~~^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/array_api_compat/_internal.py", line 35, in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ".venv/lib/python3.13/site-packages/array_api_compat/common/_aliases.py", line 424, in clip
    out[()] = x
    ~~~^^^^
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.

I totally understand if full support for torch.vmap is out of scope, but figured it might be worth raising the issue in case there's something which requires fixing.

TimothyEDawson avatar Sep 04 '25 17:09 TimothyEDawson

that is a shame! https://github.com/data-apis/array-api-compat/blob/be9eff7e382997608d3fa8c87fd559c0810dd366/array_api_compat/common/_aliases.py#L438-L439

For most things, [()] works to get NumPy scalars for NumPy, and as a no-op for other arrays. But with vmap it seems the no-op assumption breaks down.
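For concreteness, here's a small NumPy-only sketch of the `[()]` idiom (purely illustrative, not array-api-compat internals):

```python
import numpy as np

# On a 0-d NumPy array, empty-tuple indexing extracts a NumPy scalar:
zero_d = np.asarray(2.0)
scalar = zero_d[()]        # np.float64, no longer an ndarray

# On arrays with ndim >= 1 it is effectively a no-op (returns a view):
arr = np.asarray([1.0, 2.0])
view = arr[()]             # still an ndarray over the same data

# As an assignment *target*, however, out[()] = x is an in-place write,
# and that mutation is what torch.vmap rejects for batched tensors:
out = np.empty(2)
out[()] = arr              # copies arr's values into out
```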

In the long run, I think this calls for an as_scalar_if_numpy helper to deal with the pesky NumPy special case in a safe way. Not sure if there is a simpler workaround in the short term. WDYT @ev-br ?
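A rough sketch of what such a helper might look like (the name and shape are just the idea floated above, not an actual implementation):

```python
import numpy as np

def as_scalar_if_numpy(x):
    # Hypothetical helper: convert a 0-d NumPy array to a NumPy scalar,
    # and pass every other object through untouched, so the common code
    # never needs an in-place write like `out[()] = x`.
    if isinstance(x, np.ndarray) and x.ndim == 0:
        return x[()]
    return x
```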

lucascolley avatar Sep 04 '25 18:09 lucascolley

In array-api-compat specifically, I think it's a matter of a simple refactoring to move the [()] out of common/ and into numpy/ and the other backend-specific folders.

EDIT: one other offender is common/sign.

ev-br avatar Sep 04 '25 18:09 ev-br

Good shout!

lucascolley avatar Sep 04 '25 18:09 lucascolley

It would be great to better understand the origin of this error. If I replace out[()] = x with out = torch.clone(x) at https://github.com/data-apis/array-api-compat/blob/be9eff7e382997608d3fa8c87fd559c0810dd366/array_api_compat/common/_aliases.py#L424, it immediately fails a few lines further down, at https://github.com/data-apis/array-api-compat/blob/be9eff7e382997608d3fa8c87fd559c0810dd366/array_api_compat/common/_aliases.py#L430

While all of this can be worked around, it raises the question of what the deeper reason behind it is. Surely it cannot be that indexing operations are incompatible with vmap?

ev-br avatar Sep 08 '25 16:09 ev-br

Ah, OK, this seems to be the answer: https://discuss.pytorch.org/t/vmap-inplace-arithmetic-error/196556/2

you can’t have a write to an existing tensor

So mutating indexing operations are indeed incompatible with torch.vmap. Which means array_api_compat.torch.clip should be rewritten in terms of torch.clamp instead of emulating it as it does at the moment.
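To illustrate the out-of-place approach, here is a sketch using NumPy's where so it runs without torch (the torch analogue of each step would be torch.clamp; this is not the code from the linked PR):

```python
import numpy as np

def clip_out_of_place(x, min=None, max=None):
    # Hypothetical sketch: build the result with out-of-place ops only,
    # so no existing array/tensor is ever written into. In torch this is
    # what makes the function safe under vmap, where in-place writes to
    # non-batched tensors are rejected.
    out = x
    if min is not None:
        out = np.where(out < min, min, out)  # torch: torch.clamp(out, min=min)
    if max is not None:
        out = np.where(out > max, max, out)  # torch: torch.clamp(out, max=max)
    return out
```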

ev-br avatar Sep 08 '25 16:09 ev-br

A tentative fix is in https://github.com/data-apis/array-api-compat/pull/353, but I'm not sure it's entirely correct because of promotion rules. Thoughts?

ev-br avatar Sep 08 '25 18:09 ev-br