
Handle transition from jax v0.6.0 to jax v0.7.0

Open Giulero opened this issue 5 months ago • 0 comments

The tests action is failing (https://github.com/ami-iit/adam/actions/runs/16558803260/job/46824485937) with:

AttributeError: jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and removed in JAX v0.7.0. Please use the newer DLPack API based on __dlpack__ and __dlpack_device__ instead. Typically, you can pass a JAX array directly to the `from_dlpack` function of another framework without using `to_dlpack`.

This is caused by the recent JAX release (https://docs.jax.dev/en/latest/changelog.html#jax-0-7-0-july-22-2025), which removed `jax.dlpack.to_dlpack`; jax2torch has not yet been updated to the new DLPack API.
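As the error message suggests, the capsule-based `to_dlpack` step can simply be dropped: a JAX array can be handed directly to PyTorch's `torch.from_dlpack`, which reads `__dlpack__`/`__dlpack_device__`. A minimal sketch, assuming a recent PyTorch (with `torch.from_dlpack`) and JAX >= 0.7.0; the helper names `jax_to_torch` and `torch_to_jax` are hypothetical and are not part of adam or jax2torch:

```python
import jax
import jax.numpy as jnp
import torch

# Removed in JAX v0.7.0 (capsule-based pattern referred to by the error):
#   capsule = jax.dlpack.to_dlpack(x)
#   t = torch.utils.dlpack.from_dlpack(capsule)


def jax_to_torch(x: jax.Array) -> torch.Tensor:
    # JAX arrays implement __dlpack__ / __dlpack_device__, so torch imports them directly.
    return torch.from_dlpack(x)


def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # Detach from autograd and make contiguous before exporting through DLPack.
    return jax.dlpack.from_dlpack(t.detach().contiguous())
```

The imported tensor may share memory with the JAX buffer, so clone it before modifying it in place.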

This affects the batched PyTorch interface, which uses JAX to vectorize the RBDA functions and jax2torch to convert the results to PyTorch.

Three possible solutions:

  • make jax2torch compatible with JAX v0.7.0 (PR)
  • implement the JAX -> PyTorch conversion logic ourselves
  • implement the batched functions using pure PyTorch types (this would be the optimal solution, but it would take some effort; maybe related to #98). This does not directly fix the transition, but it would solve the general issue.
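For the second option, the part of jax2torch that adam relies on could be reimplemented in a few lines: a `torch.autograd.Function` that evaluates the JAX function with `jax.vjp` in `forward`, applies the stored vector-Jacobian product in `backward`, and moves data across with the DLPack protocol instead of `to_dlpack`. This is a minimal sketch under stated assumptions (every positional argument is a tensor to differentiate, the function returns a single array, all tensors live on the same device); `jax2torch_sketch`, `t2j`, `j2t` and `batched_sq_norm` are hypothetical names, not the jax2torch or adam API:

```python
import jax
import jax.numpy as jnp
import torch


def t2j(t: torch.Tensor) -> jax.Array:
    # Detach from autograd and make contiguous before exporting through DLPack.
    return jax.dlpack.from_dlpack(t.detach().contiguous())


def j2t(x: jax.Array) -> torch.Tensor:
    # No to_dlpack capsule needed: torch reads __dlpack__ from the JAX array.
    return torch.from_dlpack(x)


def jax2torch_sketch(fn):
    """Wrap a JAX function (tensor args -> one array) as a differentiable torch function.

    Hypothetical helper: assumes all positional arguments are tensors and
    that fn returns a single array.
    """

    class _JaxFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            jax_args = tuple(t2j(a) for a in args)
            # Evaluate fn and keep its VJP closure for the backward pass.
            out, vjp_fn = jax.vjp(fn, *jax_args)
            ctx.vjp_fn = vjp_fn
            return j2t(out)

        @staticmethod
        def backward(ctx, grad_out):
            # The VJP returns one cotangent per positional input, in order.
            grads = ctx.vjp_fn(t2j(grad_out))
            return tuple(j2t(g) for g in grads)

    return _JaxFunction.apply


# Usage: vectorize a per-sample JAX function with jax.vmap, then call it from torch.
batched_sq_norm = jax2torch_sketch(jax.vmap(lambda v: jnp.sum(v ** 2)))
```

Keeping `jax.vmap` inside the wrapped function preserves the current design, where batching happens in JAX and only the inputs and outputs cross the framework boundary.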

Giulero, Jul 28 '25 08:07