adam
Handle transition from jax v0.6.0 to jax v0.7.0
The tests action is failing (https://github.com/ami-iit/adam/actions/runs/16558803260/job/46824485937) with the following error:
AttributeError: jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and removed in JAX v0.7.0. Please use the newer DLPack API based on __dlpack__ and __dlpack_device__ instead. Typically, you can pass a JAX array directly to the `from_dlpack` function of another framework without using `to_dlpack`.
This is caused by the recent JAX release (https://docs.jax.dev/en/latest/changelog.html#jax-0-7-0-july-22-2025), for which jax2torch has not been updated.
The error occurs in the batched PyTorch interface, which uses JAX to vectorize the RBDA functions and jax2torch to convert the results to PyTorch.
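Following the suggestion in the error message, the removed `jax.dlpack.to_dlpack` call can be replaced by handing arrays directly to the other framework's `from_dlpack`. The sketch below assumes JAX >= 0.7.0 and a PyTorch version with `torch.from_dlpack`; the helper names `jax_to_torch` and `torch_to_jax` are hypothetical and not part of adam or jax2torch.

```python
import jax.dlpack
import jax.numpy as jnp
import torch

# Old pattern (fails on JAX >= 0.7.0, where jax.dlpack.to_dlpack was removed):
#   torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))


def jax_to_torch(x_jax):
    # Hypothetical helper: JAX arrays implement __dlpack__ / __dlpack_device__,
    # so torch can consume them directly, without an intermediate capsule.
    return torch.from_dlpack(x_jax)


def torch_to_jax(x_torch):
    # Hypothetical helper: torch refuses to export tensors that require grad,
    # and DLPack expects a contiguous buffer, so detach and make contiguous first.
    return jax.dlpack.from_dlpack(x_torch.detach().contiguous())


x = jnp.arange(6.0).reshape(2, 3)
t = jax_to_torch(x)
back = torch_to_jax(t)
```

On CPU both conversions are typically zero-copy, so the torch tensor and the JAX array may share the same buffer.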
Three possible solutions:
- make `jax2torch` compatible with PR
- implement the logic converting JAX -> PyTorch
- implement the batched functions using pure PyTorch types. This would be the optimal solution, but it would take some effort (maybe related to #98). It does not explicitly fix the transition, but it would solve the general issue.
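For the second option, a minimal sketch of a jax2torch-style wrapper could combine the new DLPack path with `jax.vjp` inside a `torch.autograd.Function`. This is an assumption about how such logic might look, not adam's or jax2torch's actual code: `jax2torch_sketch` is a hypothetical name, and the sketch only handles float tensor inputs and a single array output.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def jax2torch_sketch(fn):
    """Hypothetical wrapper: exposes a JAX function of float tensors that
    returns a single array as a differentiable torch operation."""

    class _JaxFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            # torch -> JAX: detach (torch will not export tensors that require
            # grad) and make contiguous before sharing the buffer via DLPack.
            jax_args = [jax.dlpack.from_dlpack(a.detach().contiguous()) for a in args]
            # Evaluate fn and keep its VJP closure for the backward pass.
            out, ctx.vjp_fn = jax.vjp(fn, *jax_args)
            # JAX -> torch: pass the JAX array directly, no to_dlpack needed.
            return torch.from_dlpack(out)

        @staticmethod
        def backward(ctx, grad_out):
            # The incoming gradient may be non-contiguous (e.g. expanded by sum()).
            cotangent = jax.dlpack.from_dlpack(grad_out.detach().contiguous())
            grads = ctx.vjp_fn(cotangent)
            # One gradient per forward input, converted back to torch.
            return tuple(torch.from_dlpack(g) for g in grads)

    def wrapped(*args):
        return _JaxFunction.apply(*args)

    return wrapped


# Usage: a function vectorized with jax.vmap and called from torch.
batched_norm = jax2torch_sketch(jax.vmap(lambda v: jnp.sqrt(jnp.sum(v * v))))
x = torch.tensor([[3.0, 4.0], [6.0, 8.0]], requires_grad=True)
y = batched_norm(x)
y.sum().backward()
```

Here `x.grad` holds the row-wise gradients `v / |v|`, computed by JAX and returned to torch through DLPack.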