PyTorch 'meta' device needs per-device capabilities
PyTorch has a special device, meta, which is a dummy device with no underlying data. This is a very powerful testing tool.
However, PyTorch has not implemented support for unknown (data-dependent) shapes on this device. As a result, functions whose output shape depends on the data, such as unique, fail.
Proposal
Change capabilities() to capabilities(*, device=None), matching the signature of default_dtypes() and dtypes().
Change array_api_compat.torch:
def capabilities(*, device=None):
    device = torch.get_default_device() if device is None else torch.device(device)
    is_material = device.type != "meta"
    return {
        "boolean indexing": is_material,
        "data-dependent shapes": is_material,
        "max dimensions": 64,
    }
xref https://github.com/data-apis/array-api-extra/pull/300
+1. xref https://github.com/data-apis/array-api/discussions/777 with a bunch of interest and discussion on this kind of functionality.
@jakevdp Would you have any objections to adding a device kwarg with the default being None for this API? Would be good to get JAX's input on this.