Repository: ott

Not passing `epsilon` together with `kernel_matrix` causes a `RecursionError`

Status: Open — michalk8 opened this issue 1 year ago • 1 comment

Code to reproduce (the bug was most likely introduced in #310):

import jax.numpy as jnp
import ott
x = jnp.ones((10, 12))
ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix

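Traceback. The properties call each other in a cycle. `cost_matrix` derives the cost from the kernel and multiplies it by `self.epsilon`. `epsilon` reads `self._epsilon`, which falls back to scaling by `mean_cost_matrix` when neither a relative flag nor a target epsilon was supplied. `mean_cost_matrix` calls `apply_cost`, which reads `self.cost_matrix` again. The recursion therefore never terminates: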

RecursionError                            Traceback (most recent call last)
Cell In [1], line 4
      2 import ott
      3 x = jnp.ones((10, 12))
----> 4 ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix

File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
    109   cost = -jnp.log(self._kernel_matrix + eps)
    110   cost *= self.inv_scale_cost
--> 111   return cost if self._epsilon_init is None else self.epsilon * cost
    112 return self._cost_matrix * self.inv_scale_cost

File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
    152 @property
    153 def epsilon(self) -> float:
    154   """Epsilon regularization value."""
--> 155   return self._epsilon.target

File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
    141 use_mean_scale = rel is True or (rel is None and target is None)
    142 if scale_eps is None and use_mean_scale:
--> 143   scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
    145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
    146   return self._epsilon_init.set(scale_epsilon=scale_eps)

File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
    120 @property
    121 def mean_cost_matrix(self) -> float:
    122   """Mean of the :attr:`cost_matrix`."""
--> 123   tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
    124   return jnp.sum(tmp * self._m_normed_ones)

File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
    573   arr = arr.reshape(-1, 1)
    575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)

    [... skipping hidden 3 frame]

File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
    578 def _apply_cost_to_vec(
    579     self,
    580     vec: jnp.ndarray,
   (...)
    583     **_: Any,
    584 ) -> jnp.ndarray:
    585   """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
    586 
    587   Args:
   (...)
    594     A jnp.ndarray corresponding to cost x vector
    595   """
--> 596   matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
    597   matrix = fn(matrix) if fn is not None else matrix
    598   return jnp.dot(matrix, vec)

File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
    109   cost = -jnp.log(self._kernel_matrix + eps)
    110   cost *= self.inv_scale_cost
--> 111   return cost if self._epsilon_init is None else self.epsilon * cost
    112 return self._cost_matrix * self.inv_scale_cost

File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
    152 @property
    153 def epsilon(self) -> float:
    154   """Epsilon regularization value."""
--> 155   return self._epsilon.target

File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
    141 use_mean_scale = rel is True or (rel is None and target is None)
    142 if scale_eps is None and use_mean_scale:
--> 143   scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
    145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
    146   return self._epsilon_init.set(scale_epsilon=scale_eps)

File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
    120 @property
    121 def mean_cost_matrix(self) -> float:
    122   """Mean of the :attr:`cost_matrix`."""
--> 123   tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
    124   return jnp.sum(tmp * self._m_normed_ones)

File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
    573   arr = arr.reshape(-1, 1)
    575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)

    [... skipping hidden 3 frame]

File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
    578 def _apply_cost_to_vec(
    579     self,
    580     vec: jnp.ndarray,
   (...)
    583     **_: Any,
    584 ) -> jnp.ndarray:
    585   """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
    586 
    587   Args:
   (...)
    594     A jnp.ndarray corresponding to cost x vector
    595   """
--> 596   matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
    597   matrix = fn(matrix) if fn is not None else matrix
    598   return jnp.dot(matrix, vec)

    [... skipping similar frames: Geometry._epsilon at line 143 (294 times), Geometry.cost_matrix at line 111 (294 times), Geometry.epsilon at line 155 (294 times), Geometry.mean_cost_matrix at line 123 (294 times), Geometry._apply_cost_to_vec at line 596 (293 times), Geometry.apply_cost at line 576 (293 times), WrappedFun.call_wrapped at line 165 (293 times), api_boundary.<locals>.reraise_with_filtered_traceback at line 166 (293 times), vmap.<locals>.vmap_f at line 1773 (293 times)]

File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
    573   arr = arr.reshape(-1, 1)
    575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)

    [... skipping hidden 3 frame]

File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
    578 def _apply_cost_to_vec(
    579     self,
    580     vec: jnp.ndarray,
   (...)
    583     **_: Any,
    584 ) -> jnp.ndarray:
    585   """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
    586 
    587   Args:
   (...)
    594     A jnp.ndarray corresponding to cost x vector
    595   """
--> 596   matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
    597   matrix = fn(matrix) if fn is not None else matrix
    598   return jnp.dot(matrix, vec)

File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
    109   cost = -jnp.log(self._kernel_matrix + eps)
    110   cost *= self.inv_scale_cost
--> 111   return cost if self._epsilon_init is None else self.epsilon * cost
    112 return self._cost_matrix * self.inv_scale_cost

File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
    152 @property
    153 def epsilon(self) -> float:
    154   """Epsilon regularization value."""
--> 155   return self._epsilon.target

File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
    141 use_mean_scale = rel is True or (rel is None and target is None)
    142 if scale_eps is None and use_mean_scale:
--> 143   scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
    145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
    146   return self._epsilon_init.set(scale_epsilon=scale_eps)

File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
    120 @property
    121 def mean_cost_matrix(self) -> float:
    122   """Mean of the :attr:`cost_matrix`."""
--> 123   tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
    124   return jnp.sum(tmp * self._m_normed_ones)

File ~/Projects/ott/src/ott/geometry/geometry.py:862, in Geometry._n_normed_ones(self)
    860 """Normalized array of shape ``[num_a,]``."""
    861 mask = self.src_mask
--> 862 arr = jnp.ones(self.shape[0]) if mask is None else mask
    863 return arr / jnp.sum(arr)

File ~/.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2150, in ones(shape, dtype)
   2148 shape = canonicalize_shape(shape)
   2149 dtypes.check_user_dtype_supported(dtype, "ones")
-> 2150 return lax.full(shape, 1, _jnp_dtype(dtype))

    [... skipping hidden 17 frame]

File ~/.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/config.py:241, in Config.define_bool_state.<locals>.get_state(self)
    240 def get_state(self):
--> 241   val = _thread_local_state.__dict__.get(name, unset)
    242   return val if val is not unset else self._read(name)

Posted by michalk8 on Mar 30 '23, 18:03