`ot.emd2()` does not work as expected with empty weights if the JAX backend is used

Open Francis-Hsu opened this issue 8 months ago • 3 comments

Describe the bug

Per the documentation of ot.emd2(), uniform weights are used if empty lists are passed as the weight arguments. However, doing so with the JAX backend causes a broadcasting error.

To Reproduce

Simulate some data first:

import jax
import numpy as np
import ot
from jax import numpy as jnp


key = jax.random.PRNGKey(1)
key_x, key_y = jax.random.split(key)  # separate keys so that x and y differ
x = jax.random.normal(key_x, (100, 2))
y = jax.random.normal(key_y, (100, 2))

With the NumPy backend, the following works without issue:

from opt_einsum import contract

M = contract('mi,ni->mn', x, y, backend='numpy') ** 2.
emt = np.empty((0))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis

However, an error occurs once we switch to the JAX backend (jnp):

M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt = jnp.empty((0))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis

Partial error message:

File c:\ProgramData\anaconda3\Lib\site-packages\ot\lp\__init__.py:567, in emd2.<locals>.f(b)
    559     warnings.warn(
    560         "Input histogram consists of integer. The transport plan will be "
    561         "casted accordingly, possibly resulting in a loss of precision. "
   (...)
    564         stacklevel=2
    565     )
    566 G = nx.from_numpy(G, type_as=type_as)
--> 567 cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
    568                         (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
    569                                        nx.from_numpy(v - np.mean(v), type_as=type_as), G))
    571 check_result(result_code)
    572 return cost

File c:\ProgramData\anaconda3\Lib\site-packages\ot\backend.py:1392, in JaxBackend.set_gradients(self, val, inputs, grads)
   1389 ravelled_inputs, _ = ravel_pytree(inputs)
   1390 ravelled_grads, _ = ravel_pytree(grads)
-> 1392 aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
   1393 aux = aux - jax.lax.stop_gradient(aux)
   1395 val, = jax.tree_map(lambda z: z + aux, (val,))

File c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\array_methods.py:256, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    254 args = (other, self) if swap else (self, other)
    255 if isinstance(other, _accepted_binop_types):
--> 256   return binary_op(*args)
    257 # Note: don't use isinstance here, because we don't want to raise for
    258 # subclasses, e.g. NamedTuple objects that may override operators.
    259 if type(other) in _rejected_binop_types:

    [... skipping hidden 12 frame]

File c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\ufuncs.py:97, in _maybe_bool_binop.<locals>.fn(x1, x2)
     95 def fn(x1, x2, /):
     96   x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 97   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 7 frame]

File c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\lax\lax.py:1591, in broadcasting_shape_rule(name, *avals)
   1589       result_shape.append(non_1s[0])
   1590     else:
-> 1591       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1592                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1594 return tuple(result_shape)

TypeError: mul got incompatible shapes for broadcasting: (10000,), (10200,).
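
The two sizes in the error message can be traced back to the set_gradients frame in the traceback: the flattened inputs (a0, b0, M0) contain 0 + 0 + 100*100 = 10000 entries when the weights are empty, while the flattened gradients (u, v, G) contain 100 + 100 + 100*100 = 10200 entries. The following is a minimal sketch of that size mismatch only, not POT's actual code; the zero-filled arrays are placeholders chosen for illustration.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

n = 100

# Placeholders standing in for the quantities named in the traceback.
a0 = jnp.empty((0,))    # empty weights, as passed by the user
b0 = jnp.empty((0,))
M0 = jnp.zeros((n, n))  # cost matrix

u = jnp.zeros((n,))     # dual potentials: one entry per sample
v = jnp.zeros((n,))
G = jnp.zeros((n, n))   # transport plan

# Flatten both tuples the same way set_gradients does in the traceback.
ravelled_inputs, _ = ravel_pytree((a0, b0, M0))
ravelled_grads, _ = ravel_pytree((u, v, G))

print(ravelled_inputs.shape, ravelled_grads.shape)  # (10000,) (10200,)
```

The element-wise product of these two vectors is what fails to broadcast.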

Possible solution:

The problem can be avoided by constructing the uniform weights explicitly:

M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt0 = jnp.ones((M.shape[0],)) / M.shape[0]
emt1 = jnp.ones((M.shape[1],)) / M.shape[1]
Wass_dis = ot.emd2(emt0, emt1, M=M)
Wass_dis # correct result
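
The workaround above can be wrapped in a small helper so that empty weights are replaced before they reach the solver. This is only a sketch: uniform_if_empty is a hypothetical name and is not part of POT, and the small all-ones matrix M is a stand-in for a real cost matrix.

```python
import jax.numpy as jnp


def uniform_if_empty(w, n, dtype):
    """Hypothetical helper (not part of POT): return uniform weights of
    length n when w is None or empty, otherwise w as a JAX array."""
    if w is None or len(w) == 0:
        return jnp.full((n,), 1.0 / n, dtype=dtype)
    return jnp.asarray(w, dtype=dtype)


M = jnp.ones((3, 4))  # placeholder cost matrix

# Works for both an empty Python list and an empty JAX array.
a = uniform_if_empty([], M.shape[0], M.dtype)
b = uniform_if_empty(jnp.empty((0,)), M.shape[1], M.dtype)
```

The resulting a and b are JAX arrays of the same type as M, so they can be passed as ot.emd2(a, b, M=M) in the same way as emt0 and emt1 above.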

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows
  • Python version: 3.11.4
  • How was POT installed (source, pip, conda): pip

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Windows-10-10.0.22621-SP0
Python 3.11.4 | packaged by Anaconda, Inc. | (main, Jul  5 2023, 13:38:37) [MSC v.1916 64 bit (AMD64)]
NumPy 1.24.3
SciPy 1.10.1
POT 0.9.1

Francis-Hsu, Oct 15 '23 22:10

Hello @Francis-Hsu, and thanks for the feedback. Could you do a quick check and see whether the bug also occurs when you provide actual empty Python lists (a=[]) instead of empty JAX arrays?

Unless I'm mistaken, the documentation states "empty list", and the function should handle this well for any backend.

rflamary, Oct 16 '23 13:10

Also note that with the new API, weights are now optional and there is no need for empty lists:

Wass_dis = ot.solve(M).value

rflamary, Oct 16 '23 13:10

Hi @rflamary. Thank you for the feedback. If I use ot.emd2([], [], M=M), I get a type-checking error:

ValueError: All array should be from the same type/backend. Current types are : [<class 'jaxlib.xla_extension.ArrayImpl'>, <class 'numpy.ndarray'>, <class 'numpy.ndarray'>]

But indeed the ot.solve(M) interface is much more convenient. I didn't know about it until now :P

Francis-Hsu, Oct 16 '23 17:10

This one should be fixed in #606

rflamary, Mar 04 '24 12:03