array-api-compat
Automatically use the correct device in xp.clip when Python number literals are passed as bounds
I would like the following not to fail with PyTorch:
>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
  Cell In[4], line 1
    xp.clip(data, 0.1, 0.9)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
    ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
At the moment, using xp.clip with PyTorch on non-CPU tensors requires explicitly creating the bounds as arrays on the same device:
>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')
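Until xp.clip handles this itself, the workaround above could be wrapped in a small helper. The following is only a rough sketch, not how array-api-compat implements or plans to implement clip: the name `clip_scalar_bounds` is hypothetical, and it assumes the input array exposes a `.device` attribute (as PyTorch tensors do) and that `xp.asarray` accepts a `device` keyword.

```python
def clip_scalar_bounds(xp, x, lo=None, hi=None):
    """Hypothetical helper: clip x, first moving Python-number bounds to x's device."""
    # Assumption: x exposes .device, as PyTorch tensors do.
    dev = x.device

    def to_array(bound):
        # Only Python int/float literals are converted; arrays and None
        # are passed through unchanged.
        if isinstance(bound, (int, float)):
            return xp.asarray(bound, device=dev)
        return bound

    return xp.clip(x, to_array(lo), to_array(hi))
```

With `import array_api_compat.torch as xp` and `data` on the "mps" device as above, `clip_scalar_bounds(xp, data, 0.1, 0.9)` is intended to behave like the verbose call with explicit `xp.asarray(..., device=device_)` bounds.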