
Automatically use the correct device in xp.clip when Python number literals are passed as bounds

Open · ogrisel opened this issue 6 months ago · 5 comments

I would like the following not to fail with PyTorch:

>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
  Cell In[4], line 1
    xp.clip(data, 0.1, 0.9)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
    ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

At the moment, using xp.clip with PyTorch on non-CPU tensors is overly verbose, because each bound must be explicitly converted to an array on the input's device:

>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')

ogrisel, Aug 08 '24 14:08