ott
ott copied to clipboard
Not passing `epsilon` with a kernel matrix causes a `RecursionError`
Code to reproduce (most likely introduced in #310):
import jax.numpy as jnp
import ott
x = jnp.ones((10, 12))
ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix
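Traceback. The frames form a cycle that repeats until Python's recursion limit is reached:
`cost_matrix` (kernel branch) multiplies by `self.epsilon` → `epsilon` reads `self._epsilon` → `_epsilon` needs `mean_cost_matrix` to scale epsilon → `mean_cost_matrix` calls `apply_cost` → `_apply_cost_to_vec` reads `self.cost_matrix` again.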
RecursionError Traceback (most recent call last)
Cell In [1], line 4
2 import ott
3 x = jnp.ones((10, 12))
----> 4 ott.geometry.geometry.Geometry(kernel_matrix=x).cost_matrix
File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
109 cost = -jnp.log(self._kernel_matrix + eps)
110 cost *= self.inv_scale_cost
--> 111 return cost if self._epsilon_init is None else self.epsilon * cost
112 return self._cost_matrix * self.inv_scale_cost
File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
152 @property
153 def epsilon(self) -> float:
154 """Epsilon regularization value."""
--> 155 return self._epsilon.target
File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
141 use_mean_scale = rel is True or (rel is None and target is None)
142 if scale_eps is None and use_mean_scale:
--> 143 scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
146 return self._epsilon_init.set(scale_epsilon=scale_eps)
File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
120 @property
121 def mean_cost_matrix(self) -> float:
122 """Mean of the :attr:`cost_matrix`."""
--> 123 tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
124 return jnp.sum(tmp * self._m_normed_ones)
File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
573 arr = arr.reshape(-1, 1)
575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)
[... skipping hidden 3 frame]
File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
578 def _apply_cost_to_vec(
579 self,
580 vec: jnp.ndarray,
(...)
583 **_: Any,
584 ) -> jnp.ndarray:
585 """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
586
587 Args:
(...)
594 A jnp.ndarray corresponding to cost x vector
595 """
--> 596 matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
597 matrix = fn(matrix) if fn is not None else matrix
598 return jnp.dot(matrix, vec)
File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
109 cost = -jnp.log(self._kernel_matrix + eps)
110 cost *= self.inv_scale_cost
--> 111 return cost if self._epsilon_init is None else self.epsilon * cost
112 return self._cost_matrix * self.inv_scale_cost
File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
152 @property
153 def epsilon(self) -> float:
154 """Epsilon regularization value."""
--> 155 return self._epsilon.target
File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
141 use_mean_scale = rel is True or (rel is None and target is None)
142 if scale_eps is None and use_mean_scale:
--> 143 scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
146 return self._epsilon_init.set(scale_epsilon=scale_eps)
File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
120 @property
121 def mean_cost_matrix(self) -> float:
122 """Mean of the :attr:`cost_matrix`."""
--> 123 tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
124 return jnp.sum(tmp * self._m_normed_ones)
File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
573 arr = arr.reshape(-1, 1)
575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)
[... skipping hidden 3 frame]
File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
578 def _apply_cost_to_vec(
579 self,
580 vec: jnp.ndarray,
(...)
583 **_: Any,
584 ) -> jnp.ndarray:
585 """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
586
587 Args:
(...)
594 A jnp.ndarray corresponding to cost x vector
595 """
--> 596 matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
597 matrix = fn(matrix) if fn is not None else matrix
598 return jnp.dot(matrix, vec)
[... skipping similar frames: Geometry._epsilon at line 143 (294 times), Geometry.cost_matrix at line 111 (294 times), Geometry.epsilon at line 155 (294 times), Geometry.mean_cost_matrix at line 123 (294 times), Geometry._apply_cost_to_vec at line 596 (293 times), Geometry.apply_cost at line 576 (293 times), WrappedFun.call_wrapped at line 165 (293 times), api_boundary.<locals>.reraise_with_filtered_traceback at line 166 (293 times), vmap.<locals>.vmap_f at line 1773 (293 times)]
File ~/Projects/ott/src/ott/geometry/geometry.py:576, in Geometry.apply_cost(self, arr, axis, fn, **kwargs)
573 arr = arr.reshape(-1, 1)
575 app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
--> 576 return jax.vmap(app, in_axes=1, out_axes=1)(arr)
[... skipping hidden 3 frame]
File ~/Projects/ott/src/ott/geometry/geometry.py:596, in Geometry._apply_cost_to_vec(self, vec, axis, fn, **_)
578 def _apply_cost_to_vec(
579 self,
580 vec: jnp.ndarray,
(...)
583 **_: Any,
584 ) -> jnp.ndarray:
585 """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
586
587 Args:
(...)
594 A jnp.ndarray corresponding to cost x vector
595 """
--> 596 matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
597 matrix = fn(matrix) if fn is not None else matrix
598 return jnp.dot(matrix, vec)
File ~/Projects/ott/src/ott/geometry/geometry.py:111, in Geometry.cost_matrix(self)
109 cost = -jnp.log(self._kernel_matrix + eps)
110 cost *= self.inv_scale_cost
--> 111 return cost if self._epsilon_init is None else self.epsilon * cost
112 return self._cost_matrix * self.inv_scale_cost
File ~/Projects/ott/src/ott/geometry/geometry.py:155, in Geometry.epsilon(self)
152 @property
153 def epsilon(self) -> float:
154 """Epsilon regularization value."""
--> 155 return self._epsilon.target
File ~/Projects/ott/src/ott/geometry/geometry.py:143, in Geometry._epsilon(self)
141 use_mean_scale = rel is True or (rel is None and target is None)
142 if scale_eps is None and use_mean_scale:
--> 143 scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
145 if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
146 return self._epsilon_init.set(scale_epsilon=scale_eps)
File ~/Projects/ott/src/ott/geometry/geometry.py:123, in Geometry.mean_cost_matrix(self)
120 @property
121 def mean_cost_matrix(self) -> float:
122 """Mean of the :attr:`cost_matrix`."""
--> 123 tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
124 return jnp.sum(tmp * self._m_normed_ones)
File ~/Projects/ott/src/ott/geometry/geometry.py:862, in Geometry._n_normed_ones(self)
860 """Normalized array of shape ``[num_a,]``."""
861 mask = self.src_mask
--> 862 arr = jnp.ones(self.shape[0]) if mask is None else mask
863 return arr / jnp.sum(arr)
File ~/.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2150, in ones(shape, dtype)
2148 shape = canonicalize_shape(shape)
2149 dtypes.check_user_dtype_supported(dtype, "ones")
-> 2150 return lax.full(shape, 1, _jnp_dtype(dtype))
[... skipping hidden 17 frame]
File ~/.mambaforge/envs/ott/lib/python3.10/site-packages/jax/_src/config.py:241, in Config.define_bool_state.<locals>.get_state(self)
240 def get_state(self):
--> 241 val = _thread_local_state.__dict__.get(name, unset)
242 return val if val is not unset else self._read(name)