POT
Jax backend: jax.errors.TracerArrayConversionError
Describe the bug
As far as I understand, I should be able to use this library with JAX through your backend switching (which depends on the input types). However, I am getting a jax.errors.TracerArrayConversionError,
which seems to arise because POT converts its inputs to numpy (not jax.numpy) in the backend, even though I pass only jax.numpy inputs.
To Reproduce
import jax.numpy as jnp
from jax import random, grad
import ot as pot
key = random.PRNGKey(0)
B = 10
key, subkey = random.split(key)
x = random.normal(subkey, (B, 1))
key, subkey = random.split(key)
y = random.normal(subkey, (B, 1))
def loss_fn(x, y):
    costs = jnp.linalg.norm(x[:, None] - y[None, :], axis=-1)**2
    pi = pot.emd(
        jnp.ones(B) / B,
        jnp.ones(B) / B,
        costs)
    return jnp.sum(pi * costs)
g = grad(loss_fn)(x, y)
print(g)
(Note: the problem isn't specific to grad; it also applies to vmap, jit, ...)
Traceback (most recent call last):
  File "/home/telfaralex/Documents/phdv2/code/sinterp/sinterp/tests/test_couplings.py", line 58, in test_grad
    g = grad(loss_fn)(x, y)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/Documents/phdv2/code/sinterp/sinterp/tests/test_couplings.py", line 51, in loss_fn
    pi = ot_fn(
         ^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/lp/__init__.py", line 318, in emd
    M, a, b = nx.to_numpy(M, a, b)
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 260, in to_numpy
    return [self._to_numpy(array) for array in arrays]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 260, in <listcomp>
    return [self._to_numpy(array) for array in arrays]
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 1439, in _to_numpy
    return np.array(a)
           ^^^^^^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10,10]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
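To separate the failure from POT itself, here is a minimal sketch that assumes the error comes solely from the np.array(a) call at the bottom of the traceback. The function to_numpy_then_sum is a hypothetical stand-in for that conversion step, not POT code. It runs fine on a concrete array but should fail as soon as JAX traces it:

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_numpy_then_sum(a):
    # Hypothetical stand-in for the np.array(a) conversion seen in the
    # traceback: it forces the input into a concrete numpy array.
    return jnp.sum(np.array(a))


x = jnp.ones((3, 3))

# Eager call: x is a concrete array, so the conversion succeeds.
eager_value = to_numpy_then_sum(x)

# Under a transformation, the function receives a tracer instead of a
# concrete array, and converting a tracer to numpy raises.
raised = {}
for name, transform in [("grad", jax.grad), ("jit", jax.jit), ("vmap", jax.vmap)]:
    try:
        transform(to_numpy_then_sum)(x)
        raised[name] = False
    except jax.errors.TracerArrayConversionError:
        raised[name] = True

print(eager_value, raised)
```

If this sketch reflects what happens inside emd, the error is independent of the particular inputs and would affect any call made under grad, jit, or vmap.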
Environment:
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.11.4
- How was POT installed (source, pip, conda): pip, v0.9.3