
Support dlpack pinned host memory device type (kDLCPUPinned / kDLCUDAHost)


Description

Currently, JAX doesn't support importing a DLPack array that lives in pinned (page-locked) host memory. It would be nice to support this device type so that, for example, PyTorch tensors in pinned host memory can be converted to JAX arrays.

Example:

import jax
import jax.dlpack
import torch

# allocate a CPU tensor, then copy it into page-locked (pinned) host memory
a = torch.zeros((32,), device=torch.device('cpu')).pin_memory()

# fails: JAX does not recognize DLPack device type 3 (kDLCPUPinned / kDLCUDAHost)
a_jax = jax.dlpack.from_dlpack(a)

results in

Traceback (most recent call last):
  File ".../test_pin.py", line 7, in <module>
    a_jax = jax.dlpack.from_dlpack(a)
  File ".../lib/python3.10/site-packages/jax/_src/dlpack.py", line 275, in from_dlpack
    return _from_dlpack(external_array, device, copy)
  File ".../lib/python3.10/site-packages/jax/_src/dlpack.py", line 204, in _from_dlpack
    dl_device_platform = {
KeyError: <DLDeviceType.kDLCPUPinned: 3>
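For illustration, here is a simplified, hypothetical sketch of the failure mode (not JAX's actual code): `_from_dlpack` looks the DLPack device type up in a platform table that has no entry for device type 3, so pinned host memory falls through. The enum values come from the DLPack spec (`dlpack.h`), where `kDLCPUPinned` was renamed `kDLCUDAHost`.

```python
from enum import IntEnum

# DLPack device types (values from the DLPack spec, dlpack.h)
class DLDeviceType(IntEnum):
    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3  # formerly kDLCPUPinned: host memory pinned for CUDA

# Simplified stand-in for the lookup table in jax/_src/dlpack.py:
# only plain CPU and CUDA devices are mapped, so type 3 has no entry.
dl_device_platform = {
    DLDeviceType.kDLCPU: "cpu",
    DLDeviceType.kDLCUDA: "cuda",
}

def platform_for(device_type: DLDeviceType) -> str:
    """Resolve a DLPack device type to a platform name, or fail loudly."""
    try:
        return dl_device_platform[device_type]
    except KeyError:
        raise TypeError(
            "Array passed to from_dlpack is on unsupported device type "
            f"(DLDeviceType: {int(device_type)})"
        ) from None
```

Newer jaxlib versions convert the bare `KeyError` into the `TypeError` shown later in this thread, but the root cause is the same missing table entry.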

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='5.15.0-1063', version='#69~20.04.1-Ubuntu SMP Fri May 10 19:20:12 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Jul  9 15:04:56 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:10:1C.0 Off |                    0 |
| N/A   51C    P0              94W / 400W |    425MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   3517632      C   python                                      416MiB |
+---------------------------------------------------------------------------------------+

JesseFarebro avatar Jul 09 '24 15:07 JesseFarebro

We are ready to support this on our side, but it looks like PyTorch does not currently use kDLCUDAHost, see pytorch/pytorch#136250.

superbobry avatar Feb 17 '25 10:02 superbobry

PyTorch does support the device type: the C code is incomplete, but https://github.com/pytorch/pytorch/blob/b020971e7806bba39aecf636e59e743911831ad8/torch/_tensor.py#L1758-L1759 fills in the blank.
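To make the point concrete, here is a minimal mock of the relevant part of the DLPack protocol (`FakePinnedTensor` is hypothetical, not a PyTorch class): a producer reports where its buffer lives via `__dlpack_device__`, returning a `(device_type, device_id)` pair, and a pinned CPU tensor reports device type 3 rather than plain CPU. This is what a consumer like `jax.dlpack.from_dlpack` inspects before importing the buffer.

```python
# DLPack device-type values from the spec; kDLCUDAHost (3) is
# CUDA-pinned host memory (formerly named kDLCPUPinned).
kDLCPU, kDLCUDA, kDLCUDAHost = 1, 2, 3

class FakePinnedTensor:
    """Hypothetical stand-in for a CPU tensor with pin_memory() applied."""

    def __init__(self, pinned: bool):
        self.pinned = pinned

    def __dlpack_device__(self):
        # Consumers call this to learn the buffer's location before
        # deciding whether (and how) to import it.
        return (kDLCUDAHost, 0) if self.pinned else (kDLCPU, 0)
```

So a consumer that only accepts device types 1 and 2 will reject the pinned tensor even though the producer describes it correctly.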

>>> jax.dlpack.from_dlpack(torch.zeros((32,), device=torch.device('cpu')).pin_memory())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/.pixi/envs/dev/lib/python3.12/site-packages/jax/_src/dlpack.py", line 276, in from_dlpack
    return _from_dlpack(external_array, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/.pixi/envs/dev/lib/python3.12/site-packages/jax/_src/dlpack.py", line 214, in _from_dlpack
    raise TypeError(
TypeError: Array passed to from_dlpack is on unsupported device type (DLDeviceType: 3, array: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
>>> jax.dlpack.from_dlpack(torch.zeros((32,), device=torch.device('cpu')))
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
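One possible resolution, sketched here as an assumption rather than JAX's actual fix: since CUDA-pinned memory is ordinary host-addressable memory, the device-type table could simply map `kDLCUDAHost` to the CPU platform, making the pinned case behave like the plain-CPU case that already works above.

```python
# DLPack spec values
kDLCPU, kDLCUDA, kDLCUDAHost = 1, 2, 3

# Hypothetical extension of the device-type table: page-locked host
# memory still lives on the host, so import it as a CPU buffer.
dl_device_platform = {
    kDLCPU: "cpu",
    kDLCUDA: "cuda",
    kDLCUDAHost: "cpu",  # treat pinned host memory as plain CPU
}
```

Whether the imported array should additionally remember that the backing memory is pinned (e.g. for zero-copy host-to-device transfers) is a separate design question.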

SobhanMP avatar Jun 17 '25 20:06 SobhanMP