
RuntimeError with from_dlpack(copy=True)

Open haohuanw opened this issue 1 month ago • 1 comments

Description

import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print()

jax_gpu_array = jax.device_put(jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32), jax.devices("gpu")[0])

# Test copy=True (expected to fail with kMutableZeroCopy error)
try:
    result = jax.dlpack.from_dlpack(jax_gpu_array, copy=True)
    print(f"from_dlpack(copy=True) SUCCEEDED: {result}")
except Exception as e:
    print(f"from_dlpack(copy=True) FAILED: {type(e).__name__}: {e}")

# Test copy=False (expected to work)
try:
    result = jax.dlpack.from_dlpack(jax_gpu_array, copy=False)
    print(f"from_dlpack(copy=False) SUCCEEDED: {result}")
except Exception as e:
    print(f"from_dlpack(copy=False) FAILED: {type(e).__name__}: {e}")

# Workaround: copy=False then jnp.array with copy=True
try:
    result_nocopy = jax.dlpack.from_dlpack(jax_gpu_array, copy=False)
    result_copied = jnp.array(result_nocopy, copy=True)
    print(f"Workaround (copy=False + jnp.array copy=True) SUCCEEDED: {result_copied}")
except Exception as e:
    print(f"Workaround FAILED: {type(e).__name__}: {e}")

Output:

JAX version: 0.8.1
Devices: [CudaDevice(id=0)]

from_dlpack(copy=True) FAILED: JaxRuntimeError: UNIMPLEMENTED: PJRT C API does not support HostBufferSemantics other than HostBufferSemantics::kImmutableOnlyDuringCall, HostBufferSemantics::kImmutableZeroCopy and HostBufferSemantics::kImmutableUntilTransferCompletes.
from_dlpack(copy=False) SUCCEEDED: [1. 2. 3. 4.]
Workaround (copy=False + jnp.array copy=True) SUCCEEDED: [1. 2. 3. 4.]

I believe this was introduced by https://github.com/jax-ml/jax/commit/ec1e65e3438d790c47eb34e0d8320a0b28ba1fcd, where the copy=True code path changed. I am able to work around it temporarily and get the old code path with the following code:

    result_nocopy = jax.dlpack.from_dlpack(jax_gpu_array, copy=False)
    result_copied = jnp.array(result_nocopy, copy=True)
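
The two-step workaround above could be wrapped in a small helper so call sites do not have to change once copy=True works again. This is only a sketch: `from_dlpack_copied` is a hypothetical name, not part of JAX, and the broad `except Exception` is an assumption made to keep the example short (the failure reported here is a JaxRuntimeError).

```python
import jax
import jax.dlpack
import jax.numpy as jnp


def from_dlpack_copied(external_array):
    """Return a JAX array holding a copy of `external_array`.

    Hypothetical helper: tries from_dlpack(copy=True) first and falls back
    to the workaround from this report if that call raises.
    """
    try:
        # Direct path; on the affected GPU setup this raises JaxRuntimeError.
        return jax.dlpack.from_dlpack(external_array, copy=True)
    except Exception:
        # Workaround: import without copying, then copy explicitly.
        # If this import also fails, the exception propagates to the caller.
        aliased = jax.dlpack.from_dlpack(external_array, copy=False)
        return jnp.array(aliased, copy=True)
```

Calling `from_dlpack_copied(jax_gpu_array)` would then take whichever path succeeds on the installed build.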

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.8.1
jaxlib: 0.8.1
numpy:  2.2.6
python: 3.11.9 (main, Aug 14 2024, 05:07:28) [Clang 18.1.8 ]
device info: NVIDIA H100 80GB HBM3-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='h100-reserved-192-143', release='6.5.13-65-650-4141-22041-coreweave-amd64-85c45edc', version='#1 SMP PREEMPT_DYNAMIC Mon Oct 14 20:37:13 UTC 2024', machine='x86_64')

$ nvidia-smi
Mon Dec  8 01:41:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   28C    P0             84W /  700W |     555MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A         3575769      C   python                                  546MiB |
+-----------------------------------------------------------------------------------------+

haohuanw, Dec 08 '25 01:12

The exception is raised in https://github.com/openxla/xla/blob/958489d78808fc152c192ab003b405a46c61aba8/xla/pjrt/c_api_client/pjrt_c_api_client.cc#L798

Given the definition of HostBufferSemantics (https://github.com/openxla/xla/blob/958489d78808fc152c192ab003b405a46c61aba8/xla/pjrt/pjrt_client.h#L878), the value of host_buffer_semantics must be kMutableZeroCopy, the only enum value that the error message does not list as supported.

On non-CPU platforms, kMutableZeroCopy behaves identically to kImmutableUntilTransferCompletes. So the fix for this issue is to update the call to BufferFromHostBuffer (in MakePjrtBuffer in jaxlib/dlpack.cc) as follows:

return device.client()->BufferFromHostBuffer(
      data, element_type, dimensions, byte_strides,
      (dlmt->dl_tensor.device.device_type == kDLCPU
           ? xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy
           : xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes),
      on_delete_callback, memory_space, /*device_layout=*/nullptr);

(not tested).
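
Since the change above is untested, a small check run against a patched build would help confirm it. The following is a minimal sketch, assuming only the public `jax.dlpack.from_dlpack` call used in the reproducer; `check_from_dlpack_copy` is a hypothetical helper name, and it verifies only that the call succeeds and the values round-trip, not that the result is backed by separate memory.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import numpy as np


def check_from_dlpack_copy(device):
    """Return (ok, detail) for from_dlpack(copy=True) on `device`.

    Hypothetical helper for checking a patched build.
    """
    source = jax.device_put(jnp.arange(4, dtype=jnp.float32), device)
    try:
        result = jax.dlpack.from_dlpack(source, copy=True)
    except Exception as e:
        # The unpatched GPU build in this report raises JaxRuntimeError here.
        return False, f"{type(e).__name__}: {e}"
    same_values = bool(np.array_equal(np.asarray(result), np.asarray(source)))
    detail = "values match" if same_values else "values differ"
    return same_values, detail
```

Running `check_from_dlpack_copy(d)` for each `d` in `jax.devices()` should report `(True, "values match")` on every device if the fix behaves as intended.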

pearu, Dec 08 '25 10:12