Support dlpack pinned host memory device type (kDLCPUPinned / kDLCUDAHost)
Description
Currently, JAX doesn't support converting a DLPack array that resides in pinned host memory. It would be nice to support this device type so that, for example, PyTorch tensors in pinned host memory can be converted to JAX arrays.
Example:
import jax
import jax.dlpack
import torch
a = torch.zeros((32,), device=torch.device('cpu')).pin_memory()
a_jax = jax.dlpack.from_dlpack(a)
results in
Traceback (most recent call last):
File ".../test_pin.py", line 7, in <module>
a_jax = jax.dlpack.from_dlpack(a)
File ".../lib/python3.10/site-packages/jax/_src/dlpack.py", line 275, in from_dlpack
return _from_dlpack(external_array, device, copy)
File ".../lib/python3.10/site-packages/jax/_src/dlpack.py", line 204, in _from_dlpack
dl_device_platform = {
KeyError: <DLDeviceType.kDLCPUPinned: 3>
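For context, the KeyError comes from a device-type-to-platform lookup that has no entry for pinned host memory. Below is a minimal sketch of such a mapping with the missing entry added; the dictionary contents are illustrative, not JAX's actual table, and routing kDLCUDAHost to the CPU platform is the proposed behavior, since pinned host memory is CPU-addressable.

```python
from enum import IntEnum

# Device types from the DLPack spec (dlpack.h). Value 3 was originally
# named kDLCPUPinned and is now kDLCUDAHost (pinned host memory).
class DLDeviceType(IntEnum):
    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3  # pinned host memory (legacy alias: kDLCPUPinned)

# Illustrative device-type -> platform table. JAX's real table lives in
# jax/_src/dlpack.py and currently has no kDLCUDAHost entry, which is
# what raises the KeyError above.
dl_device_platform = {
    DLDeviceType.kDLCPU: "cpu",
    DLDeviceType.kDLCUDA: "cuda",
    DLDeviceType.kDLCUDAHost: "cpu",  # proposed: treat pinned host memory as CPU
}

# With the entry present, looking up device type 3 succeeds.
print(dl_device_platform[DLDeviceType(3)])  # -> cpu
```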
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='5.15.0-1063', version='#69~20.04.1-Ubuntu SMP Fri May 10 19:20:12 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Jul 9 15:04:56 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01 Driver Version: 535.183.01 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-80GB On | 00000000:10:1C.0 Off | 0 |
| N/A 51C P0 94W / 400W | 425MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 3517632 C python 416MiB |
+---------------------------------------------------------------------------------------+
We are ready to support this on our side, but it looks like PyTorch does not currently use kDLCUDAHost, see pytorch/pytorch#136250.
PyTorch does support the device type: its C code path is incomplete, but https://github.com/pytorch/pytorch/blob/b020971e7806bba39aecf636e59e743911831ad8/torch/_tensor.py#L1758-L1759 fills in the blank.
>>> jax.dlpack.from_dlpack(torch.zeros((32,), device=torch.device('cpu')).pin_memory())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "~/.pixi/envs/dev/lib/python3.12/site-packages/jax/_src/dlpack.py", line 276, in from_dlpack
return _from_dlpack(external_array, device, copy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/.pixi/envs/dev/lib/python3.12/site-packages/jax/_src/dlpack.py", line 214, in _from_dlpack
raise TypeError(
TypeError: Array passed to from_dlpack is on unsupported device type (DLDeviceType: 3, array: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.])
>>> jax.dlpack.from_dlpack(torch.zeros((32,), device=torch.device('cpu')))
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
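Until JAX handles device type 3, one caller-side workaround is to inspect the producer's `__dlpack_device__()` and copy through host memory when pinned memory is reported. A sketch, using NumPy as a stand-in consumer (the helper name `from_dlpack_safe` is made up; for a pinned torch tensor, the fallback branch would go through `tensor.numpy()`):

```python
import numpy as np

KDL_CUDA_HOST = 3  # DLPack device type for pinned host memory (kDLCPUPinned)

def from_dlpack_safe(x):
    """Convert a DLPack-exporting array, falling back to a host copy when
    the producer reports pinned host memory, which some consumers reject."""
    device_type, _device_id = x.__dlpack_device__()
    if device_type == KDL_CUDA_HOST:
        # Pinned host memory is CPU-addressable, so a plain copy works;
        # for a torch tensor this could be np.asarray(x.numpy()).
        return np.asarray(x)
    return np.from_dlpack(x)

# Regular CPU arrays take the zero-copy DLPack path.
a = np.zeros((32,), dtype=np.float32)
b = from_dlpack_safe(a)
```

The extra copy defeats the point of pinning, but it at least avoids the TypeError while support is pending.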