Rotation.concatenate does not work for two single rotations
Description
Currently, JAX does not correctly concatenate instances of jax.scipy.spatial.transform.Rotation when both are single rotations. Code to reproduce:
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as jRotation
q1 = jnp.array([0.0, 0.0, 1.0, 0.0])
q2 = jnp.array([0.0, 0.0, 0.0, 1.0])
r1 = jRotation.from_quat(q1)
r2 = jRotation.from_quat(q2)
r3 = jRotation.concatenate([r1, r2])
print(r3.as_quat())
print(r3.as_rotvec())
Expected output:
[[0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[0. 0. 3.1415927]
 [0. 0. 0.]]
Current output:
[0. 0. 1. 0. 0. 0. 0. 1.]
[0. 0. 3.1415927]
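The flat, length-8 quaternion output above is exactly what concatenating two 1-D arrays along axis 0 produces, and as_rotvec then appears to read only the first quaternion. Assuming Rotation keeps its quaternions in one array and concatenate joins them with jnp.concatenate (an assumption about the implementation, not something confirmed in this thread), the shapes work out like this:

```python
import jax.numpy as jnp

q1 = jnp.array([0.0, 0.0, 1.0, 0.0])
q2 = jnp.array([0.0, 0.0, 0.0, 1.0])

# Joining two (4,) arrays along axis 0 gives one flat (8,) vector,
# matching the "current output" of as_quat() in the report.
flat = jnp.concatenate([q1, q2])
print(flat.shape)  # (8,)

# Promoting each single quaternion to shape (1, 4) first keeps one row
# per rotation, giving the expected (2, 4) layout.
batched = jnp.concatenate([jnp.atleast_2d(q1), jnp.atleast_2d(q2)])
print(batched.shape)  # (2, 4)
```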
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.1.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='development-1', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May 7 09:00:52 UTC 2', machine='x86_64')
$ nvidia-smi
Thu Aug 22 23:42:23 2024
NVIDIA-SMI 555.42.02, Driver Version: 555.42.02, CUDA Version: 12.5
GPU 0: NVIDIA GeForce RTX 3060, Persistence-M: Off, Bus-Id: 00000000:01:00.0, Disp.A: Off, Volatile Uncorr. ECC: N/A
Fan: 0%, Temp: 44C, Perf: P8, Pwr:Usage/Cap: 15W / 170W, Memory-Usage: 1012MiB / 12288MiB, GPU-Util: 29%, Compute M.: Default, MIG M.: N/A
Processes (GPU 0, GI/CI N/A, all type G):
PID 1567, /usr/lib/xorg/Xorg, 667MiB
PID 2169, cinnamon, 47MiB
PID 3937, /usr/lib/firefox/firefox, 0MiB
PID 21524, ...yOnDemand --variations-seed-version, 91MiB
PID 75918, ...erProcess --variations-seed-version, 134MiB
PID 1284178, ...96,262144 --variations-seed-version, 19MiB
Hi, thanks for the report! The Rotation functionality has some implementation issues, and it is part of the package that we've (retroactively) identified as out of scope for JAX (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-spatial); at some point in the future it will probably be removed.
My hope is that ongoing efforts to make SciPy compatible with the Python array API will allow JAX users to replace these tools with the SciPy rotation code used directly, although that's not yet possible.
In the meantime, is this an issue that you can work around?
Hey, thanks for answering and sorry for taking so long to get back to you. We can definitely work around this issue.
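One possible workaround, sketched here as an assumption rather than what the reporter actually used: rebuild the combined rotation from the individual quaternions, promoting each single rotation to shape (1, 4) before concatenating. The helper name concatenate_rotations is hypothetical and not part of JAX.

```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as jRotation


def concatenate_rotations(rotations):
    """Hypothetical helper: concatenate single and/or batched rotations.

    Each rotation's quaternions (scalar-last, as returned by as_quat) are
    promoted to at least 2-D, so a single rotation contributes one row
    instead of being flattened into its neighbours.
    """
    quats = [jnp.atleast_2d(r.as_quat()) for r in rotations]
    return jRotation.from_quat(jnp.concatenate(quats, axis=0))


r1 = jRotation.from_quat(jnp.array([0.0, 0.0, 1.0, 0.0]))
r2 = jRotation.from_quat(jnp.array([0.0, 0.0, 0.0, 1.0]))
r3 = concatenate_rotations([r1, r2])
print(r3.as_quat())    # one row per rotation, shape (2, 4)
print(r3.as_rotvec())  # one rotation vector per rotation, shape (2, 3)
```

Because jnp.atleast_2d leaves already-batched (N, 4) arrays unchanged, the same helper should also work when mixing single and stacked rotations.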
It is funny that you mention the array API. From what I can understand from the SciPy issue (https://github.com/scipy/scipy/issues/18286) on the matter, they hope to "dispatch" operations that SciPy implements in C/C++/Cython/Fortran to libraries such as JAX. This seems to be quite different from your vision.
I think you're misreading that issue: while it's true that some operations will be dispatched to other libraries, my read is that this is limited to special functions which cannot be efficiently implemented in terms of the array API standard. Rotation does not fall into this category: it is an object-oriented API around operations that are easily expressible in terms of the API standard, and scipy is hard at work rewriting such APIs in terms of the array API. This is tracked in https://github.com/scipy/scipy/issues/18867.