
[BUG] DLPack Interop with JAX fails if CPU buffer is not aligned to XLA's requirements

Open shi-eric opened this issue 10 months ago • 5 comments

Bug Description

Running test_dlpack.py on aarch64 with JAX installed via pip install -U "jax[cuda12]" produces the following failures:

======================================================================
FAIL: test_dlpack_warp_to_jax_cpu (__main__.TestDLPack.test_dlpack_warp_to_jax_cpu)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/eshi/code-projects/warp/warp/tests/unittest_utils.py", line 248, in test_func
    func(self, device, **kwargs)
  File "/home/eshi/code-projects/warp/warp/tests/interop/test_dlpack.py", line 393, in test_dlpack_warp_to_jax
    test.assertEqual(a.ptr, j1.unsafe_buffer_pointer())
AssertionError: 107405155888560 != 107405132074560

======================================================================
FAIL: test_dlpack_warp_to_jax_v2_cpu (__main__.TestDLPack.test_dlpack_warp_to_jax_v2_cpu)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/eshi/code-projects/warp/warp/tests/unittest_utils.py", line 248, in test_func
    func(self, device, **kwargs)
  File "/home/eshi/code-projects/warp/warp/tests/interop/test_dlpack.py", line 430, in test_dlpack_warp_to_jax_v2
    test.assertEqual(a.ptr, j1.unsafe_buffer_pointer())
AssertionError: 107405168471520 != 107405172665920

----------------------------------------------------------------------
Ran 33 tests in 5.864s

FAILED (failures=2)

System Information

Confirmed on a Jetson AGX Orin with multiple JAX versions from 0.4.38 to 0.5.1.

Python versions: 3.12.0 (miniforge), 3.10.12 (uv)

Also confirmed on 3.12.3 (uv) on x86-64....

shi-eric avatar Feb 24 '25 23:02 shi-eric

Update: I was missing these warnings:

test_dlpack_warp_to_jax_cpu (__main__.TestDLPack) ... /home/eshi/code-projects/warp/warp/tests/interop/test_dlpack.py:388: DeprecationWarning: Calling from_dlpack with a DLPack tensor is deprecated. The argument to from_dlpack should be an array from another framework that implements the __dlpack__ protocol.
  j1 = jax.dlpack.from_dlpack(wp.to_dlpack(a))
2025-02-24 16:09:57.565789: W external/xla/xla/python/dlpack.cc:347] DLPack buffer is not aligned (data at: 0xaaab02ced8a0). Creating a copy.
/home/eshi/code-projects/warp/warp/jax.py:164: DeprecationWarning: Calling from_dlpack with a DLPack tensor is deprecated. The argument to from_dlpack should be an array from another framework that implements the __dlpack__ protocol.
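
For reference, the DeprecationWarning above is about passing a DLPack capsule instead of the array object itself. A minimal sketch of the protocol form follows, using NumPy as a stand-in consumer so it runs without Warp or JAX (that substitution is an assumption for illustration, not what the test does). The final pointer comparison mirrors the zero-copy check that fails in the test.

```python
import numpy as np

# Producer array; any object that implements __dlpack__ and __dlpack_device__ works.
a = np.arange(8, dtype=np.float32)

# Protocol form: hand the array object to from_dlpack, not a capsule.
b = np.from_dlpack(a)

# A zero-copy import shares the producer's buffer, so the data pointers match.
assert np.shares_memory(a, b)
assert a.ctypes.data == b.ctypes.data
```

When the consumer decides to copy instead, as XLA does in the warning above, the two pointers differ and an equality assertion like this one fails.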

shi-eric avatar Feb 25 '25 00:02 shi-eric

https://github.com/jax-ml/jax/discussions/6055 shows us how to create a NumPy array with an alignment that JAX is happy with.
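
A minimal sketch of one common way to get an aligned NumPy array: over-allocate a byte buffer and slice it at the first aligned address. This is not necessarily the exact code from the linked discussion; the helper name `aligned_empty` and the default of 64 bytes are assumptions for illustration.

```python
import numpy as np

def aligned_empty(shape, dtype=np.float32, align=64):
    """Return an uninitialized array whose data pointer is a multiple of `align` bytes.

    Hypothetical helper: over-allocates a uint8 buffer and views an aligned slice.
    """
    dtype = np.dtype(dtype)
    nbytes = int(np.prod(shape)) * dtype.itemsize
    # Over-allocate by `align` bytes so an aligned start always fits in the buffer.
    raw = np.empty(nbytes + align, dtype=np.uint8)
    # Number of bytes to skip to reach the next multiple of `align`.
    offset = -raw.ctypes.data % align
    # Slice at the aligned offset, then reinterpret the bytes as the target dtype.
    return raw[offset:offset + nbytes].view(dtype).reshape(shape)

a = aligned_empty((3, 5), np.float32, align=64)
assert a.ctypes.data % 64 == 0
```

The returned array keeps a reference to the over-allocated buffer through its `base`, so the memory stays alive for as long as the view does.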

shi-eric avatar Feb 25 '25 00:02 shi-eric

This comment indicates that XLA needs 16-byte alignment: https://github.com/openxla/xla/blob/b5c0101d998781261b2815ef03146e9ff28cbb32/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc#L771-L776

But this value can change depending on the value of EIGEN_MAX_ALIGN_BYTES:

https://github.com/openxla/xla/blob/main/xla/backends/cpu/alignment.h#L32
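
As a quick sanity check, the address from the warning above (0xaaab02ced8a0) can be tested against a few candidate alignments. The helper `is_aligned` is hypothetical. The result suggests, as an inference rather than a confirmed fact, that the requirement in that build was larger than 16 bytes.

```python
def is_aligned(ptr: int, align: int) -> bool:
    """Return True if the address `ptr` is a multiple of `align` bytes."""
    return ptr % align == 0

# Address reported in the XLA warning: "DLPack buffer is not aligned".
addr = 0xaaab02ced8a0

for align in (16, 32, 64):
    print(align, is_aligned(addr, align))
# prints:
# 16 True
# 32 True
# 64 False
```

The buffer satisfies 16-byte and 32-byte alignment but not 64-byte alignment, which would be consistent with a build where EIGEN_MAX_ALIGN_BYTES is 64.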

shi-eric avatar Feb 25 '25 06:02 shi-eric

@nvlukasz suggests adding a config option to control alignment, e.g. wp.config.cpu_align = 16

shi-eric avatar Feb 25 '25 06:02 shi-eric

I added dedaebce756d6f9380c182e78e080f4b74e3ddfd to avoid the test failure for now.

shi-eric avatar Feb 25 '25 06:02 shi-eric

Fixed in 01777f92872f90afa78da5801d43c5a7e2e01aa7

nvlukasz avatar Jun 11 '25 01:06 nvlukasz